Repository: bosun-ai/swiftide Branch: master Commit: 849ebf9a2a7e Files: 350 Total size: 1.9 MB Directory structure: gitextract_d7f6_hv1/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ ├── dependabot.yml │ └── workflows/ │ ├── bench.yml │ ├── coverage.yml │ ├── discord.yml │ ├── lint.yml │ ├── pr.yml │ ├── release.yml │ └── test.yml ├── .gitignore ├── .markdownlint.yaml ├── AGENTS.md ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE ├── README.md ├── benchmarks/ │ ├── Cargo.toml │ ├── fileloader.rs │ ├── local_pipeline.rs │ ├── node_cache_comparison.rs │ └── output.txt ├── cliff.toml ├── deny.toml ├── examples/ │ ├── Cargo.toml │ ├── agent_can_fail_custom_schema.rs │ ├── agents_mcp_tools.rs │ ├── agents_resume.rs │ ├── agents_with_human_in_the_loop.rs │ ├── aws_bedrock.rs │ ├── aws_bedrock_agent.rs │ ├── dashscope.rs │ ├── describe_image.rs │ ├── fastembed.rs │ ├── fluvio.rs │ ├── hello_agents.rs │ ├── hybrid_search.rs │ ├── index_codebase.rs │ ├── index_codebase_reduced_context.rs │ ├── index_groq.rs │ ├── index_into_redis.rs │ ├── index_markdown_lots_of_metadata.rs │ ├── index_md_into_pgvector.rs │ ├── index_ollama.rs │ ├── kafka.rs │ ├── lancedb.rs │ ├── langfuse.rs │ ├── query_pipeline.rs │ ├── reranking.rs │ ├── responses_api.rs │ ├── responses_api_reasoning.rs │ ├── scraping_index_to_markdown.rs │ ├── stop_with_args_custom_schema.rs │ ├── store_multiple_vectors.rs │ ├── streaming_agents.rs │ ├── structured_prompt.rs │ ├── tasks.rs │ ├── tool_custom_schema.rs │ └── usage_metrics.rs ├── release-plz.toml ├── renovate.json ├── rust-toolchain.toml ├── rustfmt.toml ├── swiftide/ │ ├── Cargo.toml │ ├── build.rs │ ├── src/ │ │ ├── lib.rs │ │ └── test_utils.rs │ └── tests/ │ ├── dyn_traits.rs │ ├── indexing_pipeline.rs │ ├── lancedb.rs │ ├── pgvector.rs │ ├── query_pipeline.rs │ └── sparse_embeddings_and_hybrid_search.rs ├── swiftide-agents/ │ ├── Cargo.toml │ └── src/ │ ├── agent.rs │ ├── default_context.rs │ ├── 
errors.rs │ ├── hooks.rs │ ├── lib.rs │ ├── snapshots/ │ │ ├── swiftide_agents__system_prompt__tests__customization.snap │ │ └── swiftide_agents__system_prompt__tests__to_prompt.snap │ ├── state.rs │ ├── system_prompt.rs │ ├── system_prompt_template.md │ ├── tasks/ │ │ ├── closures.rs │ │ ├── errors.rs │ │ ├── impls.rs │ │ ├── mod.rs │ │ ├── node.rs │ │ ├── task.rs │ │ └── transition.rs │ ├── test_utils.rs │ ├── tools/ │ │ ├── arg_preprocessor.rs │ │ ├── control.rs │ │ ├── local_executor.rs │ │ ├── mcp.rs │ │ └── mod.rs │ └── util.rs ├── swiftide-core/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ ├── agent_traits.rs │ ├── chat_completion/ │ │ ├── chat_completion_request.rs │ │ ├── chat_completion_response.rs │ │ ├── chat_message.rs │ │ ├── errors.rs │ │ ├── mod.rs │ │ ├── tool_schema.rs │ │ ├── tools.rs │ │ └── traits.rs │ ├── document.rs │ ├── indexing_decorators.rs │ ├── indexing_defaults.rs │ ├── indexing_stream.rs │ ├── indexing_traits.rs │ ├── lib.rs │ ├── metadata.rs │ ├── metrics.rs │ ├── node.rs │ ├── prelude.rs │ ├── prompt.rs │ ├── query.rs │ ├── query_evaluation.rs │ ├── query_stream.rs │ ├── query_traits.rs │ ├── search_strategies/ │ │ ├── custom_strategy.rs │ │ ├── hybrid_search.rs │ │ ├── mod.rs │ │ └── similarity_single_embedding.rs │ ├── statistics.rs │ ├── stream_backoff.rs │ ├── test_utils.rs │ ├── token_estimation.rs │ ├── type_aliases.rs │ └── util.rs ├── swiftide-indexing/ │ ├── Cargo.toml │ └── src/ │ ├── lib.rs │ ├── loaders/ │ │ ├── file_loader.rs │ │ └── mod.rs │ ├── persist/ │ │ ├── memory_storage.rs │ │ └── mod.rs │ ├── pipeline.rs │ └── transformers/ │ ├── chunk_markdown.rs │ ├── chunk_text.rs │ ├── embed.rs │ ├── metadata_keywords.rs │ ├── metadata_qa_text.rs │ ├── metadata_summary.rs │ ├── metadata_title.rs │ ├── mod.rs │ ├── prompts/ │ │ ├── metadata_keywords.prompt.md │ │ ├── metadata_qa_text.prompt.md │ │ ├── metadata_summary.prompt.md │ │ └── metadata_title.prompt.md │ ├── snapshots/ │ │ ├── 
swiftide_indexing__transformers__compress_code_outline__test__compress_code_template.snap │ │ ├── swiftide_indexing__transformers__metadata_keywords__test__template.snap │ │ ├── swiftide_indexing__transformers__metadata_qa_code__test__template.snap │ │ ├── swiftide_indexing__transformers__metadata_qa_code__test__template_with_outline.snap │ │ ├── swiftide_indexing__transformers__metadata_qa_text__test__template.snap │ │ ├── swiftide_indexing__transformers__metadata_summary__test__template.snap │ │ └── swiftide_indexing__transformers__metadata_title__test__template.snap │ └── sparse_embed.rs ├── swiftide-integrations/ │ ├── Cargo.toml │ └── src/ │ ├── anthropic/ │ │ ├── chat_completion.rs │ │ ├── mod.rs │ │ ├── simple_prompt.rs │ │ └── tool_schema.rs │ ├── aws_bedrock_v2/ │ │ ├── chat_completion.rs │ │ ├── mod.rs │ │ ├── simple_prompt.rs │ │ ├── structured_prompt.rs │ │ ├── test_utils.rs │ │ └── tool_schema.rs │ ├── dashscope/ │ │ ├── config.rs │ │ └── mod.rs │ ├── duckdb/ │ │ ├── extensions.sql │ │ ├── hybrid_query.sql │ │ ├── mod.rs │ │ ├── node_cache.rs │ │ ├── persist.rs │ │ ├── retrieve.rs │ │ ├── schema.sql │ │ └── upsert.sql │ ├── fastembed/ │ │ ├── embedding_model.rs │ │ ├── mod.rs │ │ ├── rerank.rs │ │ └── sparse_embedding_model.rs │ ├── fluvio/ │ │ ├── loader.rs │ │ └── mod.rs │ ├── gemini/ │ │ ├── config.rs │ │ └── mod.rs │ ├── groq/ │ │ ├── config.rs │ │ └── mod.rs │ ├── kafka/ │ │ ├── loader.rs │ │ ├── mod.rs │ │ └── persist.rs │ ├── lancedb/ │ │ ├── connection_pool.rs │ │ ├── mod.rs │ │ ├── persist.rs │ │ └── retrieve.rs │ ├── lib.rs │ ├── ollama/ │ │ ├── config.rs │ │ └── mod.rs │ ├── open_router/ │ │ ├── config.rs │ │ └── mod.rs │ ├── openai/ │ │ ├── chat_completion.rs │ │ ├── embed.rs │ │ ├── mod.rs │ │ ├── responses_api.rs │ │ ├── simple_prompt.rs │ │ ├── structured_prompt.rs │ │ └── tool_schema.rs │ ├── parquet/ │ │ ├── loader.rs │ │ ├── mod.rs │ │ └── test.parquet │ ├── pgvector/ │ │ ├── fixtures.rs │ │ ├── mod.rs │ │ ├── persist.rs │ │ ├── 
pgv_table_types.rs │ │ └── retrieve.rs │ ├── qdrant/ │ │ ├── indexing_node.rs │ │ ├── mod.rs │ │ ├── persist.rs │ │ └── retrieve.rs │ ├── redb/ │ │ ├── mod.rs │ │ └── node_cache.rs │ ├── redis/ │ │ ├── message_history.rs │ │ ├── mod.rs │ │ ├── node_cache.rs │ │ └── persist.rs │ ├── scraping/ │ │ ├── html_to_markdown_transformer.rs │ │ ├── loader.rs │ │ └── mod.rs │ ├── tiktoken/ │ │ └── mod.rs │ └── treesitter/ │ ├── chunk_code.rs │ ├── code_tree.rs │ ├── compress_code_outline.rs │ ├── metadata_qa_code.rs │ ├── metadata_refs_defs_code.rs │ ├── mod.rs │ ├── outline_code_tree_sitter.rs │ ├── outliner.rs │ ├── prompts/ │ │ ├── compress_code_outline.prompt.md │ │ └── metadata_qa_code.prompt.md │ ├── queries.rs │ ├── snapshots/ │ │ ├── swiftide_integrations__treesitter__compress_code_outline__test__compress_code_template.snap │ │ ├── swiftide_integrations__treesitter__metadata_qa_code__test__default_prompt.snap │ │ └── swiftide_integrations__treesitter__metadata_qa_code__test__template_with_outline.snap │ ├── splitter.rs │ └── supported_languages.rs ├── swiftide-langfuse/ │ ├── Cargo.toml │ ├── src/ │ │ ├── apis/ │ │ │ ├── configuration.rs │ │ │ ├── ingestion_api.rs │ │ │ └── mod.rs │ │ ├── langfuse_batch_manager.rs │ │ ├── lib.rs │ │ ├── models/ │ │ │ ├── create_event_body.rs │ │ │ ├── create_generation_body.rs │ │ │ ├── create_score_value.rs │ │ │ ├── create_span_body.rs │ │ │ ├── ingestion_batch_request.rs │ │ │ ├── ingestion_error.rs │ │ │ ├── ingestion_event.rs │ │ │ ├── ingestion_event_one_of.rs │ │ │ ├── ingestion_event_one_of_1.rs │ │ │ ├── ingestion_event_one_of_2.rs │ │ │ ├── ingestion_event_one_of_3.rs │ │ │ ├── ingestion_event_one_of_4.rs │ │ │ ├── ingestion_event_one_of_5.rs │ │ │ ├── ingestion_event_one_of_6.rs │ │ │ ├── ingestion_event_one_of_7.rs │ │ │ ├── ingestion_event_one_of_8.rs │ │ │ ├── ingestion_event_one_of_9.rs │ │ │ ├── ingestion_response.rs │ │ │ ├── ingestion_success.rs │ │ │ ├── ingestion_usage.rs │ │ │ ├── map_value.rs │ │ │ ├── mod.rs │ │ 
│ ├── model_usage_unit.rs │ │ │ ├── observation_body.rs │ │ │ ├── observation_level.rs │ │ │ ├── observation_type.rs │ │ │ ├── open_ai_completion_usage_schema.rs │ │ │ ├── open_ai_response_usage_schema.rs │ │ │ ├── open_ai_usage.rs │ │ │ ├── optional_observation_body.rs │ │ │ ├── score_body.rs │ │ │ ├── score_data_type.rs │ │ │ ├── sdk_log_body.rs │ │ │ ├── trace_body.rs │ │ │ ├── update_generation_body.rs │ │ │ ├── update_span_body.rs │ │ │ ├── usage.rs │ │ │ └── usage_details.rs │ │ └── tracing_layer.rs │ └── tests/ │ ├── full_flow.rs │ └── snapshots/ │ └── full_flow__integration_tracing_layer_sends_to_langfuse.snap ├── swiftide-macros/ │ ├── Cargo.toml │ ├── src/ │ │ ├── indexing_transformer.rs │ │ ├── lib.rs │ │ ├── test_utils.rs │ │ └── tool/ │ │ ├── args.rs │ │ ├── mod.rs │ │ ├── snapshots/ │ │ │ ├── swiftide_macros__tool__tests__simple_tool.snap │ │ │ ├── swiftide_macros__tool__tests__snapshot_derive.snap │ │ │ ├── swiftide_macros__tool__tests__snapshot_derive_with_args.snap │ │ │ ├── swiftide_macros__tool__tests__snapshot_derive_with_generics.snap │ │ │ ├── swiftide_macros__tool__tests__snapshot_derive_with_lifetime.snap │ │ │ ├── swiftide_macros__tool__tests__snapshot_derive_with_option.snap │ │ │ ├── swiftide_macros__tool__tests__snapshot_multiple_args.snap │ │ │ ├── swiftide_macros__tool__tests__snapshot_single_arg.snap │ │ │ └── swiftide_macros__tool__tests__snapshot_single_arg_option.snap │ │ ├── tool_spec.rs │ │ └── wrapped.rs │ └── tests/ │ ├── tool/ │ │ ├── tool_derive_missing_description.rs │ │ ├── tool_derive_missing_description.stderr │ │ ├── tool_derive_pass.rs │ │ ├── tool_derive_vec_argument_pass.rs │ │ ├── tool_missing_arg_fail.rs │ │ ├── tool_missing_arg_fail.stderr │ │ ├── tool_missing_parameter_fail.rs │ │ ├── tool_missing_parameter_fail.stderr │ │ ├── tool_multiple_arguments_pass.rs │ │ ├── tool_no_argument_pass.rs │ │ ├── tool_object_argument_pass.rs │ │ └── tool_single_argument_pass.rs │ └── tool.rs ├── swiftide-query/ │ ├── Cargo.toml 
│ └── src/ │ ├── answers/ │ │ ├── mod.rs │ │ ├── simple.rs │ │ └── snapshots/ │ │ ├── swiftide_query__answers__simple__test__custom_document_template.snap │ │ ├── swiftide_query__answers__simple__test__default_prompt.snap │ │ └── swiftide_query__answers__simple__test__uses_current_if_present.snap │ ├── evaluators/ │ │ ├── mod.rs │ │ └── ragas.rs │ ├── lib.rs │ ├── query/ │ │ ├── mod.rs │ │ └── pipeline.rs │ ├── query_transformers/ │ │ ├── embed.rs │ │ ├── generate_subquestions.rs │ │ ├── mod.rs │ │ ├── snapshots/ │ │ │ └── swiftide_query__query_transformers__generate_subquestions__test__default_prompt.snap │ │ └── sparse_embed.rs │ └── response_transformers/ │ ├── mod.rs │ ├── snapshots/ │ │ └── swiftide_query__response_transformers__summary__test__default_prompt.snap │ └── summary.rs ├── swiftide-test-utils/ │ ├── Cargo.toml │ └── src/ │ ├── lib.rs │ └── test_utils.rs └── typos.toml ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: "" labels: "" assignees: "" --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** A clear and minimal example to reproduce the bug. **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. 
Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/dependabot.yml ================================================ # To get started with Dependabot version updates, you'll need to specify which # package ecosystems to update and where the package manifests are located. # Please see the documentation for all configuration options: # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file version: 2 updates: - package-ecosystem: "cargo" # See documentation for possible values directory: "/" # Location of package manifests schedule: interval: "daily" groups: development: dependency-type: "development" tree-sitter: patterns: - "tree-sitter*" aws: patterns: - "aws*" minor: update-types: - "minor" - "patch" - package-ecosystem: "github-actions" directory: "/" schedule: interval: "daily" ================================================ FILE: .github/workflows/bench.yml ================================================ name: Bench on: push: branches: - master permissions: contents: write deployments: write jobs: benchmark: name: Benchmark runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - name: Install Protoc uses: arduino/setup-protoc@v3 - name: Run benchmark run: cargo bench -p benchmarks -- --output-format bencher | tee benchmarks/output.txt - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: name: Rust Benchmark tool: "cargo" output-file-path: benchmarks/output.txt github-token: ${{ github.token }} auto-push: true # Show alert with commit comment on 
detecting possible performance regression alert-threshold: "200%" comment-on-alert: true fail-on-alert: true alert-comment-cc-users: "@timonv" ================================================ FILE: .github/workflows/coverage.yml ================================================ name: Coverage on: pull_request: push: branches: - master concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-coverage cancel-in-progress: true env: RUSTFLAGS: "-Dwarnings -Clink-arg=-fuse-ld=lld" jobs: test: name: coverage runs-on: ubuntu-latest steps: - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main - name: Checkout repository uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@nightly with: components: llvm-tools-preview - name: Install Protoc uses: arduino/setup-protoc@v3 - name: Install cargo-llvm-cov uses: taiki-e/install-action@v2 with: tool: cargo-llvm-cov - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y lld libcurl4-openssl-dev - name: Generate code coverage run: | cargo llvm-cov --tests -j 2 --all-features --lcov --output-path lcov.info - name: Coveralls uses: coverallsapp/github-action@v2 ================================================ FILE: .github/workflows/discord.yml ================================================ on: release: types: [published] jobs: github-releases-to-discord: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v6 - name: Github Releases To Discord uses: SethCohen/github-releases-to-discord@v1.19.0 with: webhook_url: ${{ secrets.DISCORD_WEBHOOK_URL }} color: "2105893" username: "Release Changelog" avatar_url: "https://cdn.discordapp.com/avatars/487431320314576937/bd64361e4ba6313d561d54e78c9e7171.png" footer_title: "Changelog" footer_icon_url: "https://cdn.discordapp.com/avatars/487431320314576937/bd64361e4ba6313d561d54e78c9e7171.png" footer_timestamp: true ================================================ FILE: 
.github/workflows/lint.yml ================================================ name: CI on: pull_request: merge_group: push: branches: - master concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-lint env: CARGO_TERM_COLOR: always jobs: lint: name: Lint runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: components: clippy - uses: r7kamura/rust-problem-matchers@v1 - name: Install Protoc uses: arduino/setup-protoc@v3 - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y libcurl4-openssl-dev - name: Check typos uses: crate-ci/typos@master # - name: Lint dependencies # uses: EmbarkStudios/cargo-deny-action@v2 - name: clippy run: cargo clippy --all-targets --all-features --workspace env: RUSTFLAGS: "-Dwarnings" lint-formatting: name: Lint formatting runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@nightly with: components: rustfmt - uses: r7kamura/rust-problem-matchers@v1 - name: "Rustfmt" run: cargo +nightly fmt --all -- --check env: RUSTFLAGS: "-Dwarnings" hack: name: Cargo Hack runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: components: rustfmt - uses: r7kamura/rust-problem-matchers@v1 - name: Install Protoc uses: arduino/setup-protoc@v3 - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y libcurl4-openssl-dev - name: Install cargo-hack uses: taiki-e/install-action@v2 with: tool: cargo-hack - name: Check features with Cargo Hack run: cargo hack check --each-feature --no-dev-deps ================================================ FILE: .github/workflows/pr.yml ================================================ name: Check Pull Requests on: pull_request_target: types: - opened - edited - synchronize - labeled - unlabeled merge_group: permissions: pull-requests: write jobs: check-title: runs-on: ubuntu-latest steps: - name: 
Check PR title if: github.event_name == 'pull_request_target' uses: amannn/action-semantic-pull-request@v6 id: check_pr_title env: GITHUB_TOKEN: ${{ github.token }} # Add comment indicating we require pull request titles to follow conventional commits specification - uses: marocchino/sticky-pull-request-comment@v2 if: always() && (steps.check_pr_title.outputs.error_message != null) with: header: pr-title-lint-error message: | Thank you for opening this pull request! We require pull request titles to follow the [Conventional Commits specification](https://www.conventionalcommits.org/en/v1.0.0/) and it looks like your proposed title needs to be adjusted. Details: > ${{ steps.check_pr_title.outputs.error_message }} # Delete a previous comment when the issue has been resolved - if: ${{ steps.check_pr_title.outputs.error_message == null }} uses: marocchino/sticky-pull-request-comment@v2 with: header: pr-title-lint-error delete: true check-breaking-change-label: runs-on: ubuntu-latest env: # use an environment variable to pass untrusted input to the script # see https://securitylab.github.com/research/github-actions-untrusted-input/ PR_TITLE: ${{ github.event.pull_request.title }} steps: - name: Check breaking change label id: check_breaking_change run: | pattern='^(build|chore|ci|docs|feat|fix|perf|refactor|revert|style|test)(\(\w+\))?!:' # Check if pattern matches if echo "${PR_TITLE}" | grep -qE "$pattern"; then echo "breaking_change=true" >> "$GITHUB_OUTPUT" else echo "breaking_change=false" >> "$GITHUB_OUTPUT" fi - name: Add label if: steps.check_breaking_change.outputs.breaking_change == 'true' uses: actions/github-script@v8 with: github-token: ${{ github.token }} script: | github.rest.issues.addLabels({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, labels: ['breaking change'] }) do-not-merge: if: ${{ contains(github.event.*.labels.*.name, 'do not merge') }} name: Prevent Merging runs-on: ubuntu-latest steps: - name: Check 
for label run: | echo "Pull request is labeled as 'do not merge'" echo "This workflow fails so that the pull request cannot be merged" exit 1 ================================================ FILE: .github/workflows/release.yml ================================================ name: Release permissions: pull-requests: write contents: write on: push: branches: - master jobs: release-swiftide: name: Crates.io runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v6 with: fetch-depth: 0 token: ${{ secrets.RELEASE_PLZ_TOKEN }} - name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable - name: Install Protoc uses: arduino/setup-protoc@v3 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Run release-plz uses: MarcoIeni/release-plz-action@v0.5 env: GITHUB_TOKEN: ${{ secrets.RELEASE_PLZ_TOKEN }} CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} GITHUB_REPO: ${{ github.repository }} ================================================ FILE: .github/workflows/test.yml ================================================ name: CI on: pull_request: merge_group: push: branches: - master concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-test env: CARGO_TERM_COLOR: always RUSTFLAGS: "-Dwarnings -Clink-arg=-fuse-ld=lld" jobs: test: name: Test runs-on: ubuntu-latest steps: - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - name: Install Protoc uses: arduino/setup-protoc@v3 - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y lld libcurl4-openssl-dev - name: "Test" run: cargo test -j 2 --tests --all-features --no-fail-fast docs: name: Docs runs-on: ubuntu-latest steps: - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - name: Install Protoc uses: arduino/setup-protoc@v3 - name: Install system 
dependencies run: | sudo apt-get update sudo apt-get install -y lld libcurl4-openssl-dev - name: "Test" run: cargo test --doc --all-features --no-fail-fast ================================================ FILE: .gitignore ================================================ # Generated by Cargo # will have compiled files and executables debug/ target/ # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html # These are backup files generated by rustfmt **/*.rs.bk # MSVC Windows builds of rustc generate these, which store debugging information *.pdb # Added by cargo target tmp .env .env*.local **/.fastembed_cache .idea/ .history.cargo ================================================ FILE: .markdownlint.yaml ================================================ # configuration for https://github.com/DavidAnson/markdownlint first-line-heading: false no-inline-html: false line-length: false # to support repeated headers in the changelog no-duplicate-heading: false ================================================ FILE: AGENTS.md ================================================ # Repository Guidelines ## Project Structure & Module Organization Swiftide is a Rust workspace driven by the library in `swiftide/`, with supporting crates such as `swiftide-core/` for shared primitives, `swiftide-agents/` for agent orchestration, `swiftide-indexing/` and `swiftide-query/` for pipeline flows, and `swiftide-integrations/` for external connectors. Shared fixtures live in `swiftide-test-utils/`, while `examples/` hosts runnable demos and `benchmarks/` tracks performance scenarios. Static assets (logos and diagrams) are under `images/`. ## Build, Test, and Development Commands - `cargo check --workspace --all-features` quickly verifies the entire workspace compiles with all feature flags enabled. 
- `cargo build --workspace --all-features` compiles every crate and surfaces feature-gating issues early. - `cargo check -p swiftide-agents` is a fast way to probe agent changes before touching the rest of the workspace. - `cargo +nightly fmt --all` applies the repo `rustfmt.toml` (comment wrapping requires nightly); use `cargo +nightly fmt --all -- --check` to mirror CI formatting validation. - `cargo clippy --workspace --all-targets --all-features -- -D warnings` mirrors the main lint job and keeps us aligned with the pedantic lint profile baked into `Cargo.toml`. - `cargo test -j 2 --tests --all-features --no-fail-fast` mirrors the main CI test job for unit and integration tests. - `cargo test --doc --all-features --no-fail-fast` mirrors the docs test job in CI. - `cargo hack check --each-feature --no-dev-deps` mirrors the Cargo Hack feature-matrix check run in CI. - `typos` mirrors the spelling check run in CI. - `cargo test --workspace` is still useful locally when you want a broader default test sweep; use `RUST_LOG=info` if you need verbose diagnostics. - Snapshot updates flow through `cargo insta review` after tests rewrite `.snap` files. ## Coding Style & Naming Conventions Follow Rust 2024 idioms with four-space indentation. Public APIs should embrace builder patterns and the naming guidance from the Rust API Guidelines: `snake_case` for functions, `UpperCamelCase` for types, and `SCREAMING_SNAKE_CASE` constants. Avoid `unsafe` blocks—`Cargo.toml` forbids them at the workspace level. Keep comments concise so `wrap_comments = true` can format them within 100 columns. ## Testing Guidelines Prefer focused crate runs such as `cargo test -p swiftide-integrations` when iterating, and opt into `-- --ignored` for heavier scenarios. Integration tests rely on `testcontainers`, so ensure Docker is available; keep fixtures inside `swiftide-test-utils/` to reuse container helpers. 
For `insta` snapshots, commit reviewed `.snap.new` diffs only after `cargo insta review` removes pending files. ## Commit & Pull Request Guidelines Commits follow conventional syntax (`feat(agents): …`, `fix(indexing): …`) with a lowercase imperative summary. Pull request titles are also checked against the conventional commits format in CI, and titles ending in `!` receive the `breaking change` label automatically. Each PR should describe the change, link any GitHub issue, note API or schema impacts, and include before/after traces or logs when behavior changes. Update docs (README, website, or inline rustdoc) and add tests or benchmarks alongside functional work. Before requesting review, run the full lint and test suite listed above. ## Tooling & Environment Notes The workspace pins `stable` in `rust-toolchain.toml`; use the same channel unless a nightly tool is explicitly required. Dependency hygiene is enforced with `cargo deny --workspace`, and spelling checks may run via `typos`. Store local credentials with `mise` or environment variables—never commit secrets. ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. 
## [0.32.1](https://github.com/bosun-ai/swiftide/compare/v0.32.0...v0.32.1) - 2025-11-08 ### New features - [8bca0ef](https://github.com/bosun-ai/swiftide/commit/8bca0efa246e6adac061006f5f72cc9dd038cc8f) *(integrations/tree-sitter)* Add C# support ([#967](https://github.com/bosun-ai/swiftide/pull/967)) - [da35870](https://github.com/bosun-ai/swiftide/commit/da358708c83459c7f990027759fa5c56a2b647b9) Custom schema for fail tool ([#966](https://github.com/bosun-ai/swiftide/pull/966)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.32.0...0.32.1 ## [0.32.0](https://github.com/bosun-ai/swiftide/compare/v0.31.3...v0.32.0) - 2025-11-05 ### New features - [9ae3331](https://github.com/bosun-ai/swiftide/commit/9ae33317bbcbf5e65e3aa7eb0bf378190b7c33b5) *(agents)* [**breaking**] Improve toolspec api with schemars and support all possible types ([#940](https://github.com/bosun-ai/swiftide/pull/940)) **BREAKING CHANGE**: macro-level `json_type` overrides beyond the basic primitives are no longer enforced; rely on Rust type inference or provide an explicit schemars-derived struct/custom schema when specific shapes are required - [a0cc8d7](https://github.com/bosun-ai/swiftide/commit/a0cc8d73a6ce9a82a03a78e8f83957d3c1455584) *(agents)* Stop with args with optional schema ([#950](https://github.com/bosun-ai/swiftide/pull/950)) - [8ad7d97](https://github.com/bosun-ai/swiftide/commit/8ad7d97b6911bd3c676c79a2d5318c31dad23e9f) *(agents)* Add configurable timeouts to commands and local executor ([#963](https://github.com/bosun-ai/swiftide/pull/963)) - [29289d3](https://github.com/bosun-ai/swiftide/commit/29289d37cb9c49fba89376c125194fc430c57a37) *(agents)* [**breaking**] Add working directories for executor and commands ([#941](https://github.com/bosun-ai/swiftide/pull/941)) **BREAKING CHANGE**: Add working directories for executor and commands ([#941](https://github.com/bosun-ai/swiftide/pull/941)) - 
[ce724e5](https://github.com/bosun-ai/swiftide/commit/ce724e56034d717aafde08bb6c2d9dc163c66caf) *(agents/mcp)* Prefix mcp tools with the server name ([#958](https://github.com/bosun-ai/swiftide/pull/958)) ### Bug fixes - [04cd88b](https://github.com/bosun-ai/swiftide/commit/04cd88b74c7a0dd962c093181884db0afe7b6d2d) *(docs)* Replace `feature(doc_auto_cfg)` with `doc(auto_cfg)` - [7873ce5](https://github.com/bosun-ai/swiftide/commit/7873ce5941a7abf8ed60df4ec2ea8a7a4c1d1316) *(integrations/openai)* Simplify responses api and improve chat completion request ergonomics ([#956](https://github.com/bosun-ai/swiftide/pull/956)) - [24328d0](https://github.com/bosun-ai/swiftide/commit/24328d07e61a4f02679ee6b63a38561d12acefd4) *(macros)* Ensure deny_unknown_attributes is set on generated args ([#948](https://github.com/bosun-ai/swiftide/pull/948)) - [54245d0](https://github.com/bosun-ai/swiftide/commit/54245d0e70aff580d0e12d68e174026edfdb4801) Update async-openai and fix responses api ([#964](https://github.com/bosun-ai/swiftide/pull/964)) - [72a6c92](https://github.com/bosun-ai/swiftide/commit/72a6c92764aeda4e88a7cf18d26ce600b7ba8a28) Force additionalProperties properly on completion requests ([#949](https://github.com/bosun-ai/swiftide/pull/949)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.31.3...0.32.0 ## [0.31.3](https://github.com/bosun-ai/swiftide/compare/v0.31.2...v0.31.3) - 2025-10-06 ### New features - [a189ae6](https://github.com/bosun-ai/swiftide/commit/a189ae6de51571810f98cf58f9fdb58e7707f29a) *(integrations/openai)* Opt-in responses api ([#943](https://github.com/bosun-ai/swiftide/pull/943)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.31.2...0.31.3 ## [0.31.2](https://github.com/bosun-ai/swiftide/compare/v0.31.1...v0.31.2) - 2025-09-23 ### New features - [f35c9b5](https://github.com/bosun-ai/swiftide/commit/f35c9b507e11f76ff7e78de35843b3310a25f3db) *(agents)* Add builder lite methods to SystemPrompt -
[9f533f5](https://github.com/bosun-ai/swiftide/commit/9f533f57b2c7ed4ac1988f9e3567cda42f64b824) *(agents)* Add helpers to retrieve or mutate the system prompt - [febb7eb](https://github.com/bosun-ai/swiftide/commit/febb7eb282af98ce1124636cb66a8819265e3585) *(agents)* Support appending any kind of string to default SystemPrompt - [992478e](https://github.com/bosun-ai/swiftide/commit/992478ec8912554f73e3af6467784fd9326461c5) *(integrations/tree-sitter)* Splitter support for PHP ([#932](https://github.com/bosun-ai/swiftide/pull/932)) ### Bug fixes - [5df7a48](https://github.com/bosun-ai/swiftide/commit/5df7a483bed7d980bceef5e69fd7e1415da7563f) *(agents)* Only log error tool calls if error after hook - [54dceec](https://github.com/bosun-ai/swiftide/commit/54dceece5b939a0b534891ee5902593920a3fdeb) *(agents/local-executor)* Also respect workdir in read file and write file - [6a688b4](https://github.com/bosun-ai/swiftide/commit/6a688b4be6a5a443ac72aa8ec0165ce6a0bebf11) *(agents/local-executor)* Respect workdir when running commands - [5b01c58](https://github.com/bosun-ai/swiftide/commit/5b01c5854432569638fa54225268e48b4133178d) *(langfuse)* Use swiftide Usage in SimplePrompt ([#929](https://github.com/bosun-ai/swiftide/pull/929)) ### Miscellaneous - [ec1e301](https://github.com/bosun-ai/swiftide/commit/ec1e301eec2793613186b9e3bcb02de52741b936) *(agents)* Explicit read file test for local executor - [8882a53](https://github.com/bosun-ai/swiftide/commit/8882a538f30c7ff457dcb3a1d48e623fbc5aad1d) Improve tests for control tools ([#928](https://github.com/bosun-ai/swiftide/pull/928)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.31.1...0.31.2 ## [0.31.1](https://github.com/bosun-ai/swiftide/compare/v0.31.0...v0.31.1) - 2025-09-16 ### Docs - [866b77a](https://github.com/bosun-ai/swiftide/commit/866b77a8c33b6b7935f260c1df099d89492cb048) *(readme)* Use raw links for images so they work on crates/docs - 
[513c143](https://github.com/bosun-ai/swiftide/commit/513c143cd11ae6ddda48f73012844f1f6d026ef7) *(readme)* Remove double back-to-top **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.31.0...0.31.1 ## [0.31.0](https://github.com/bosun-ai/swiftide/compare/v0.30.1...v0.31.0) - 2025-09-16 ### New features - [ad6655d](https://github.com/bosun-ai/swiftide/commit/ad6655dc448defc3a9ef8401f0528da11e16a256) *(agents)* Add helper to remove default stop tool from agent builder - [708ebe4](https://github.com/bosun-ai/swiftide/commit/708ebe436b4d2e9456723cfc95557071f2c636c9) *(agents)* Implement From for SystemPromptBuilder - [db79f21](https://github.com/bosun-ai/swiftide/commit/db79f21c323abca462a5f469814c4c03cc949b7e) *(agents/tasks)* Add helper to create instant transitions from node ids - [ac7cd22](https://github.com/bosun-ai/swiftide/commit/ac7cd2209e1792b693acbde251a1aa756bb35541) *(indexing)* [**breaking**] Prepare for multi modal and node transformations with generic indexing ([#899](https://github.com/bosun-ai/swiftide/pull/899)) **BREAKING CHANGE**: Indexing pipelines are now generic over their inner type. This is a major change that enables major cool stuff in the future. Most of Swiftide still runs on Node, and will be migrated when needed/appropriate. A `TextNode` alias is provided and most indexing traits now take the node's inner generic parameter as Input/Output associated types. 
- [4e20804](https://github.com/bosun-ai/swiftide/commit/4e20804cc78a90e61a1c816abe5810b2a34007af) *(integrations)* More convenient usage reporting via callback ([#897](https://github.com/bosun-ai/swiftide/pull/897)) - [5923532](https://github.com/bosun-ai/swiftide/commit/592353259018b39d4ce43b4a15a9dea1aa1d2904) *(integrations/openai, core)* Add `StructuredPrompt` and implement for OpenAI ([#912](https://github.com/bosun-ai/swiftide/pull/912)) - [d2681d5](https://github.com/bosun-ai/swiftide/commit/d2681d53ce235439885ace40ac08a6d4a058259a) Integrate with Langfuse via tracing and make traces consistent and pretty ([#907](https://github.com/bosun-ai/swiftide/pull/907)) - [b3f18cd](https://github.com/bosun-ai/swiftide/commit/b3f18cd00f9019496274142aa89342da115c6843) Add convenience helpers to get ToolOutput values as ref ### Bug fixes - [0071b72](https://github.com/bosun-ai/swiftide/commit/0071b721520d585f36d1ec6ff90eb88d669da043) *(agents)* Replace tools when adding multiple with the same name - [dab4cf7](https://github.com/bosun-ai/swiftide/commit/dab4cf771cd9a6d90ae0985c83171fd87b213cba) *(integrations)* Remove sync requirement in future from `on_usage_async` - [6702314](https://github.com/bosun-ai/swiftide/commit/6702314eb6d937353324ce601f2a35c2a13d4cc1) *(langfuse)* Ensure all data is on the right generation span ([#913](https://github.com/bosun-ai/swiftide/pull/913)) - [e389c8b](https://github.com/bosun-ai/swiftide/commit/e389c8ba72435ba1c1af109934b2b580fb6be7c1) *(langfuse)* Set type field correctly on `SimplePrompt` ### Miscellaneous - [5ba9a7d](https://github.com/bosun-ai/swiftide/commit/5ba9a7db6f844687b04c5fa5d9a2119f456108c6) *(agents)* Implement default for `AgentCanFail` tool - [412dacb](https://github.com/bosun-ai/swiftide/commit/412dacb554d2b1478f3286a47352a6daed3079b9) *(agents/tasks)* Clean up closure api for node registration - [478d583](https://github.com/bosun-ai/swiftide/commit/478d5830fa194b880595b2c2ef9ef409cc5b34c4) *(openai)* Remove double 
`include_usage` in complete_stream ### Docs - [2117190](https://github.com/bosun-ai/swiftide/commit/211719038d1912f3ee3f165cdb721c216fa48286) Update blog post links in readme - [d5e0323](https://github.com/bosun-ai/swiftide/commit/d5e0323691a22a0b413d14d02e3bafb391e9dd7a) Update readme - [a574860](https://github.com/bosun-ai/swiftide/commit/a5748604d14e10c4010384e020e09c6082d2a7c1) Update readme ### Style - [7081e29](https://github.com/bosun-ai/swiftide/commit/7081e291216491618fb07e1ac3f947a99b140c7f) Fmt **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.30.1...0.31.0 ## [0.30.1](https://github.com/bosun-ai/swiftide/compare/v0.30.0...v0.30.1) - 2025-08-19 ### Bug fixes - [0114573](https://github.com/bosun-ai/swiftide/commit/011457367b7bfdc207f1f6d9ebfcbf2a2de4ac58) *(agents)* Explicitly handle out of bounds and empty edge cases for message history - [1005ac2](https://github.com/bosun-ai/swiftide/commit/1005ac219e2078c6ee12b050a7e73d48ef7f46a5) *(core)* Export tokenizer traits from the root crate - [e4c01e1](https://github.com/bosun-ai/swiftide/commit/e4c01e14fbe89cb5a16beddcb3819b66c7f1a087) *(integrations/tiktoken)* Tiktoken feature flag in root crate - [d56496d](https://github.com/bosun-ai/swiftide/commit/d56496d60719eea3752f849aee2a780eb435130e) *(integrations/tiktoken)* Fix my inability to count in late hours ### Miscellaneous - [352bf40](https://github.com/bosun-ai/swiftide/commit/352bf40ad5f74778bf41f00cff936805b8633b30) *(core)* Implement AsRef for ChatMessage - [aadfb7b](https://github.com/bosun-ai/swiftide/commit/aadfb7b89fe1fd6d04f27bc7209458de3571d1cc) *(integrations/openai)* Concise debug logs and more verbose trace - [f975d40](https://github.com/bosun-ai/swiftide/commit/f975d40beccdebd98c896d8492243a489a9b287b) *(query)* Reduce debugging noise for queries ### Style - [6a744e0](https://github.com/bosun-ai/swiftide/commit/6a744e0290ebceca3c14b675a35a460f532c4cff) Fix typos **Full Changelog**: 
https://github.com/bosun-ai/swiftide/compare/0.30.0...0.30.1 ## [0.30.0](https://github.com/bosun-ai/swiftide/compare/v0.29.0...v0.30.0) - 2025-08-16 ### New features - [dc574b4](https://github.com/bosun-ai/swiftide/commit/dc574b41b259f430bb4dc38338416ea1aa9480bb) *(agents)* Multi agent setup with graph-like Tasks ([#861](https://github.com/bosun-ai/swiftide/pull/861)) - [8740762](https://github.com/bosun-ai/swiftide/commit/87407626ef75c254fae0a677148609738fd64ccc) *(agents)* Allow mutating an existing system prompt in the builder ([#887](https://github.com/bosun-ai/swiftide/pull/887)) - [4bbf207](https://github.com/bosun-ai/swiftide/commit/4bbf207637a1aebe4e0d5b2d4030c3d1f99d4c1c) *(agents/local-executor)* Allow clearing, adding and removing env variable ([#875](https://github.com/bosun-ai/swiftide/pull/875)) - [7873493](https://github.com/bosun-ai/swiftide/commit/787349329e34956bcd205b8da64bb241c15c8e65) *(agents/local-executor)* Support running inline shebang scripts ([#874](https://github.com/bosun-ai/swiftide/pull/874)) - [a6d4379](https://github.com/bosun-ai/swiftide/commit/a6d43794ae8e549b3716ef15344471b22041cbc1) Proper streaming backoff for Chat Completion ([#895](https://github.com/bosun-ai/swiftide/pull/895)) ### Bug fixes - [2b8e138](https://github.com/bosun-ai/swiftide/commit/2b8e1389b630283a2e8c55b9997f09322b7378a9) *(openai)* More gracefully allow handling streaming errors if the client is decorated ([#891](https://github.com/bosun-ai/swiftide/pull/891)) - [f2948b5](https://github.com/bosun-ai/swiftide/commit/f2948b596d7c91c518e700c5d2589fba5a45b649) *(pipeline)* Revert cache nodes after they've been successfully ran ([#800](https://github.com/bosun-ai/swiftide/pull/800)) ([#852](https://github.com/bosun-ai/swiftide/pull/852)) ### Performance - [63a91bd](https://github.com/bosun-ai/swiftide/commit/63a91bd2d8290cbd20f4ae3914d820192ef160d2) Use Cow to in Prompt ### Miscellaneous - 
[09f421b](https://github.com/bosun-ai/swiftide/commit/09f421bcc934721ab5fcf3dc2808fe5beefcc9a2) Update rmcp and schemars ([#881](https://github.com/bosun-ai/swiftide/pull/881)) ### Docs - [84ffa45](https://github.com/bosun-ai/swiftide/commit/84ffa4507e57b252f72204f8e0df67191d97fe72) Minimal updates for tasks **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.29.0...0.30.0 ## [0.29.0](https://github.com/bosun-ai/swiftide/compare/v0.28.1...v0.29.0) - 2025-07-29 ### New features - [25a86fa](https://github.com/bosun-ai/swiftide/commit/25a86fa0403581c3c5ddc5bd237bee98f41bc153) *(agents)* Lots of utility functions for agents ([#862](https://github.com/bosun-ai/swiftide/pull/862)) - [a70840b](https://github.com/bosun-ai/swiftide/commit/a70840b4dca983bd23b54f1f7cf12b33d60b733c) *(openai)* Add helper to set the end user field for requests - [f8ddeba](https://github.com/bosun-ai/swiftide/commit/f8ddebaf57001671516db193140c2e5618000206) *(tree-sitter)* Add html support for splitting and parsing ([#850](https://github.com/bosun-ai/swiftide/pull/850)) ### Bug fixes - [aaa5cd9](https://github.com/bosun-ai/swiftide/commit/aaa5cd99d0316dcdc46afb922bbcefdfaa97da86) *(agents)* Add user message before invoking hooks ([#853](https://github.com/bosun-ai/swiftide/pull/853)) - [592be04](https://github.com/bosun-ai/swiftide/commit/592be049b798d80d6dadce6317889a14404643c8) *(agents)* Reduce verbosity of streaming hook ([#854](https://github.com/bosun-ai/swiftide/pull/854)) - [9778295](https://github.com/bosun-ai/swiftide/commit/977829550d58301f53f663b4c25fa5650ab15359) *(agents)* Ensure error causes are always accessible - [efd35da](https://github.com/bosun-ai/swiftide/commit/efd35da842288616abd55c789b727265bc549ffb) *(docs)* Fix prompt doctests - [e2670c0](https://github.com/bosun-ai/swiftide/commit/e2670c04d471dd7654e903e79f48bcfe61603b9f) *(duckdb)* Force install and update extensions ([#851](https://github.com/bosun-ai/swiftide/pull/851)) - 
[6a7ea3b](https://github.com/bosun-ai/swiftide/commit/6a7ea3b1472df209669fdf1231f0bdf4ebe6007f) *(redis)* Redis instrumentation only at trace level ### Miscellaneous - [0a8ce37](https://github.com/bosun-ai/swiftide/commit/0a8ce373325fac53946c245209afcd8bb7b2caa9) Public chat completion streaming types - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.28.1...0.29.0 ## [0.28.1](https://github.com/bosun-ai/swiftide/compare/v0.28.0...v0.28.1) - 2025-07-01 ### New features - [c671e6a](https://github.com/bosun-ai/swiftide/commit/c671e6aec7b381235f8450a8be0cbc766df72985) *(agents)* Add is_approved() and is_refused() to ToolFeedback ### Bug fixes - [68c5cda](https://github.com/bosun-ai/swiftide/commit/68c5cdafc6e457739bcfeb12d2810350659f2979) *(agents)* Prevent stack overflow when ToolExecutor has ambiguous refs - [07198d2](https://github.com/bosun-ai/swiftide/commit/07198d26389e1606e6e0f552e411196f42cf6600) *(duckdb)* Resolve 'x is an existing extension' - [e8ecc2f](https://github.com/bosun-ai/swiftide/commit/e8ecc2ff532efd07bd21e5350b8d2b6f600ca1c6) *(qdrant)* Re-export the full qdrant client - [242b8f5](https://github.com/bosun-ai/swiftide/commit/242b8f5e3d427967aa238115047a58bb9debad3b) *(qdrant)* Re-export qdrant::Filter properly **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.28.0...0.28.1 ## [0.28.0](https://github.com/bosun-ai/swiftide/compare/v0.27.2...v0.28.0) - 2025-06-30 ### New features - [9d11386](https://github.com/bosun-ai/swiftide/commit/9d11386c155773fcc77a60591cd57bc366044c71) Token usage metrics for embeddings, SimplePrompt and ChatCompletion with metric-rs ([#813](https://github.com/bosun-ai/swiftide/pull/813)) - [59c8b9c](https://github.com/bosun-ai/swiftide/commit/59c8b9cef721c3861a9d352c7fbef28e27d2f649) Stream files from tool executor for indexing ([#835](https://github.com/bosun-ai/swiftide/pull/835)) ### Bug fixes - 
[ba6ec04](https://github.com/bosun-ai/swiftide/commit/ba6ec0485dc950e83e91e6a8102becc0e8a13158) *(pipeline)* Cache nodes after they've been successfully run ([#800](https://github.com/bosun-ai/swiftide/pull/800)) - [d98827c](https://github.com/bosun-ai/swiftide/commit/d98827c9cd7bb476fdda0ef2ebb6939150b8781c) *(qdrant)* Re-export qdrant::Filter - [275efcd](https://github.com/bosun-ai/swiftide/commit/275efcdf91e85ed4327ffa948dcebe5903b178fa) Mark Loader as Send + Sync - [5974b72](https://github.com/bosun-ai/swiftide/commit/5974b72de4da2fc18d1f76adde02d02035104d5c) Integrations metrics depends on core/metrics ### Miscellaneous - [2f8c7cc](https://github.com/bosun-ai/swiftide/commit/2f8c7cc96b194264a47a8fe21abb7af5c63204f6) *(deps)* Up all crates ([#837](https://github.com/bosun-ai/swiftide/pull/837)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.27.2...0.28.0 ## [0.27.2](https://github.com/bosun-ai/swiftide/compare/v0.27.1...v0.27.2) - 2025-06-26 ### New features - [66cd7e9](https://github.com/bosun-ai/swiftide/commit/66cd7e9349673a77d8cc79e6b5acab8d56078a42) *(qdrant)* Add support for a filter in hybrid search ([#830](https://github.com/bosun-ai/swiftide/pull/830)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.27.1...0.27.2 ## [0.27.1](https://github.com/bosun-ai/swiftide/compare/v0.27.0...v0.27.1) - 2025-06-12 ### Bug fixes - [0892151](https://github.com/bosun-ai/swiftide/commit/0892151d2d02c30e38fa8629c386eaf4475da7f8) *(duckdb)* Avoid panic if duckdb gets created twice ([#818](https://github.com/bosun-ai/swiftide/pull/818)) - [0815923](https://github.com/bosun-ai/swiftide/commit/081592334f2bd8c2da30535b4e1b51e8ddd15834) *(tool-executor)* Remove conflicting implementation of AsRef for Output ### Miscellaneous - [2b64410](https://github.com/bosun-ai/swiftide/commit/2b644109796c8870d29fa1b54f6a0802cae9aaf8) *(tool-executor)* Implement AsRef for CommandOutput **Full Changelog**: 
https://github.com/bosun-ai/swiftide/compare/0.27.0...0.27.1 ## [0.27.0](https://github.com/bosun-ai/swiftide/compare/v0.26.0...v0.27.0) - 2025-06-09 ### New features - [c636eba](https://github.com/bosun-ai/swiftide/commit/c636ebaa2eb8d4ace1b5a370698c5f2817fc9c99) *(agents)* [**breaking**] Context is now generic over its backend ([#810](https://github.com/bosun-ai/swiftide/pull/810)) **BREAKING CHANGE**: The signature is now slightly different for the AgentContext. If you have implemented your own for i.e. a persisted solution, if it's *just that*, the implementation is now a lot more straightforward with the `MessageHistory` trait. - [3c937a8](https://github.com/bosun-ai/swiftide/commit/3c937a8ed4f7d28798a24b0d893f1613cd298493) *(agents)* Add helpers for creating tool errors ([#805](https://github.com/bosun-ai/swiftide/pull/805)) - [9e831d3](https://github.com/bosun-ai/swiftide/commit/9e831d3eb072748ebb21c9a16cd7d807b4d42469) *(agents)* [**breaking**] Easy human-in-the-loop flows by decorating tools ([#790](https://github.com/bosun-ai/swiftide/pull/790)) **BREAKING CHANGE**: The `Tool` trait now receives a `ToolCall` as argument instead of an `Option<&str>`. The latter is still accessible via `tool_call.args()`. 
- [814c217](https://github.com/bosun-ai/swiftide/commit/814c2174c742ff4277246505537070726ce8af92) *(duckdb)* Hybrid Search ([#807](https://github.com/bosun-ai/swiftide/pull/807)) - [254bd3a](https://github.com/bosun-ai/swiftide/commit/254bd3a32ffbd4d06abd6a4f3950a2b8556dc310) *(integrations)* Add kafka as loader and persist support ([#808](https://github.com/bosun-ai/swiftide/pull/808)) - [19a2e94](https://github.com/bosun-ai/swiftide/commit/19a2e94d262cc68c629d88b6b02a72bb9b159036) *(integrations)* Add support for Google Gemini ([#754](https://github.com/bosun-ai/swiftide/pull/754)) - [990fa5e](https://github.com/bosun-ai/swiftide/commit/990fa5e9edffebd9b70da6b57fa454f7318d642d) *(redis)* Support `MessageHistory` for redis ([#811](https://github.com/bosun-ai/swiftide/pull/811)) ### Bug fixes - [ca119bd](https://github.com/bosun-ai/swiftide/commit/ca119bdc473140437abb1bf14b496bb7bd9378de) *(agents)* Ensure approved / refused tool calls are in new completions ([#799](https://github.com/bosun-ai/swiftide/pull/799)) - [df6a12d](https://github.com/bosun-ai/swiftide/commit/df6a12dabe855f351acc3e0d104048321cb9bc0e) *(agents)* Ensure agents with no tools still have the stop tool - [cd57d12](https://github.com/bosun-ai/swiftide/commit/cd57d1207ced8651a277526d706bc3b7703912c0) *(openai)* Opt-out streaming accumulated response and only get the delta ([#809](https://github.com/bosun-ai/swiftide/pull/809)) - [da2d604](https://github.com/bosun-ai/swiftide/commit/da2d604e7e6209c83f382cf6de44f5f5c2042596) *(redb)* Explicit lifetime in table definition ### Miscellaneous - [7ac92a4](https://github.com/bosun-ai/swiftide/commit/7ac92a4f2ff4b1d1ba7e86c90c4f6c5c025cabc9) *(agents)* Direct access to executor via context ([#794](https://github.com/bosun-ai/swiftide/pull/794)) - [a21883b](https://github.com/bosun-ai/swiftide/commit/a21883b219a0079c1edc1d3c36d1d06ac906ba18) *(agents)* [**breaking**] Improved naming for existing messages and message history in default context **BREAKING 
CHANGE**: Improved naming for existing messages and message history in default context - [40bfa9c](https://github.com/bosun-ai/swiftide/commit/40bfa9c2d5685e54f247becb49698f8fdc347172) *(indexing)* Implement ChunkerTransformer for closures - [c8d7ab9](https://github.com/bosun-ai/swiftide/commit/c8d7ab90c86e674d5df5f4985121e4e81d1e4a37) *(integrations)* Improved warning when a qdrant collection exists - [d6769eb](https://github.com/bosun-ai/swiftide/commit/d6769eba0b87750fd3173ba73315973f720263ec) *(tree-sitter)* Implement Eq, Hash and AsRefStr for SupportedLanguages - [04ec29d](https://github.com/bosun-ai/swiftide/commit/04ec29d7240a8542ccd1d530bb9b104bcd57631e) Consistent logging for indexing pipeline ([#792](https://github.com/bosun-ai/swiftide/pull/792)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.26.0...0.27.0 ## [0.26.0](https://github.com/bosun-ai/swiftide/compare/v0.25.1...v0.26.0) - 2025-05-06 ### New features - [11051d5](https://github.com/bosun-ai/swiftide/commit/11051d5a1df6ea158ee84de274767fbdc70cc74e) *(agents)* `tools` on `Agent` is now public and can be used in hooks - [ebe68c1](https://github.com/bosun-ai/swiftide/commit/ebe68c104b8198b80ee5ee1f451c3272ce36841c) *(integrations)* Streaming chat completions for anthropic ([#773](https://github.com/bosun-ai/swiftide/pull/773)) - [7f5b345](https://github.com/bosun-ai/swiftide/commit/7f5b345115a3443afc9b32ca54a292fae3f5d38b) *(integrations)* Streaming chat completions for OpenAI ([#741](https://github.com/bosun-ai/swiftide/pull/741)) - [e2278fb](https://github.com/bosun-ai/swiftide/commit/e2278fb133e51f15025e114135a2bc29157242ee) *(integrations)* Customize common default settings for OpenAI requests ([#775](https://github.com/bosun-ai/swiftide/pull/775)) - [c563cf2](https://github.com/bosun-ai/swiftide/commit/c563cf270c60957dbb948113fb2299ec5eb7ed58) *(treesitter)* Add elixir support ([#776](https://github.com/bosun-ai/swiftide/pull/776)) - 
[13ae991](https://github.com/bosun-ai/swiftide/commit/13ae991b632cc95d1ae0bc7107146a145af59c74) Add usage to chat completion response ([#774](https://github.com/bosun-ai/swiftide/pull/774)) ### Bug fixes - [7836f9f](https://github.com/bosun-ai/swiftide/commit/7836f9ff31f2abeab966f80a91eab32054e61ff1) *(agents)* Use an RwLock to properly close a running MCP server - [0831c98](https://github.com/bosun-ai/swiftide/commit/0831c982cd6bb0b442396268c0681c908b6dadc2) *(openai)* Disable parallel tool calls by default ### Miscellaneous - [18dc99c](https://github.com/bosun-ai/swiftide/commit/18dc99ca1f597586ffed36e163f04f7c3689d2be) *(integrations)* Use generics for all openai variants ([#764](https://github.com/bosun-ai/swiftide/pull/764)) - [2a9d062](https://github.com/bosun-ai/swiftide/commit/2a9d062c6e19721c49c6233690ac71e9e28b6a04) *(openai)* Consistent exports across providers - [4df6dbf](https://github.com/bosun-ai/swiftide/commit/4df6dbf17fd4b87afc2cf7159c6518fcebc27438) Export macros from main crate and enable them by default ([#778](https://github.com/bosun-ai/swiftide/pull/778)) - [8b30fde](https://github.com/bosun-ai/swiftide/commit/8b30fde5e20ecbd4f0387c26e441d39f78ddca32) Rust like its 2024 ([#763](https://github.com/bosun-ai/swiftide/pull/763)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.25.1...0.26.0 ## [0.25.1](https://github.com/bosun-ai/swiftide/compare/v0.25.0...v0.25.1) - 2025-04-17 ### Bug fixes - [7102091](https://github.com/bosun-ai/swiftide/commit/710209123ba6972cd11fb0f3d364c9c83478e184) *(agents)* AgentBuilder and AgentBuilderError should be public **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.25.0...0.25.1 # Changelog All notable changes to this project will be documented in this file. 
## [0.25.0](https://github.com/bosun-ai/swiftide/compare/v0.24.0...v0.25.0) - 2025-04-16 ### New features - [4959ddf](https://github.com/bosun-ai/swiftide/commit/4959ddfe00e0424215dd9bd3e8a6acb579cc056c) *(agents)* Restore agents from an existing message history ([#742](https://github.com/bosun-ai/swiftide/pull/742)) - [6efd15b](https://github.com/bosun-ai/swiftide/commit/6efd15bf7b88d8f8656c4017676baf03a3bb510e) *(agents)* Agents now take an Into Prompt when queried ([#743](https://github.com/bosun-ai/swiftide/pull/743)) ### Bug fixes - [5db4de2](https://github.com/bosun-ai/swiftide/commit/5db4de2f0deb2028f5ffaf28b4d26336840e908c) *(agents)* Properly support nullable types for MCP tools ([#740](https://github.com/bosun-ai/swiftide/pull/740)) - [dd2ca86](https://github.com/bosun-ai/swiftide/commit/dd2ca86b214e8268262075a513711d6b9c793115) *(agents)* Do not log twice if mcp failed to stop - [5fea2e2](https://github.com/bosun-ai/swiftide/commit/5fea2e2acdca0782f88d4274bb8e106b48e1efe4) *(indexing)* Split pipeline concurrently ([#749](https://github.com/bosun-ai/swiftide/pull/749)) ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies - [0f2605a](https://github.com/bosun-ai/swiftide/commit/0f2605a61240d2c99e10ce6f5a91e6568343a78b) Pretty print RAGAS output ([#745](https://github.com/bosun-ai/swiftide/pull/745)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.24.0...0.25.0 ## [0.24.0](https://github.com/bosun-ai/swiftide/compare/v0.23.0...v0.24.0) - 2025-04-11 ### New features - [3117fc6](https://github.com/bosun-ai/swiftide/commit/3117fc62c146b0bf0949adb3cfe4e6c7f40427f7) Introduce LanguageModelError for LLM traits and an optional backoff decorator ([#630](https://github.com/bosun-ai/swiftide/pull/630)) ### Bug fixes - [0134dae](https://github.com/bosun-ai/swiftide/commit/0134daebef5d47035e986d30e1fa8f2c751c2c48) *(agents)* Gracefully stop mcp service on drop 
([#734](https://github.com/bosun-ai/swiftide/pull/734)) ### Miscellaneous - [e872c5b](https://github.com/bosun-ai/swiftide/commit/e872c5b24388754b371d9f0c7faad8647ad4733b) Core test utils available behind feature flag ([#730](https://github.com/bosun-ai/swiftide/pull/730)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.23.0...0.24.0 ## [0.23.0](https://github.com/bosun-ai/swiftide/compare/v0.22.8...v0.23.0) - 2025-04-08 ### New features - [fca4165](https://github.com/bosun-ai/swiftide/commit/fca4165c5be4b14cdc3d20ed8215ef64c5fd69a9) *(agents)* Return typed errors and yield error in `on_stop` ([#725](https://github.com/bosun-ai/swiftide/pull/725)) - [29352e6](https://github.com/bosun-ai/swiftide/commit/29352e6d3dc51779f3202e0e9936bf72e0b61605) *(agents)* Add `on_stop` hook and `stop` now takes a `StopReason` ([#724](https://github.com/bosun-ai/swiftide/pull/724)) - [a85cd8e](https://github.com/bosun-ai/swiftide/commit/a85cd8e2d014f198685ee6bfcfdf17f7f34acf91) *(macros)* Support generics in Derive for tools ([#720](https://github.com/bosun-ai/swiftide/pull/720)) - [52c44e9](https://github.com/bosun-ai/swiftide/commit/52c44e9b610c0ba4bf144881c36eacc3a0d10e53) Agent mcp client support ([#658](https://github.com/bosun-ai/swiftide/pull/658)) ````text Adds support for agents to use tools from MCP servers. All transports are supported via the `rmcp` crate. Additionally adds the possibility to add toolboxes to agents (of which MCP is one). Tool boxes declare their available tools at runtime, like tool box. ```` ### Miscellaneous - [69706ec](https://github.com/bosun-ai/swiftide/commit/69706ec6630b70ea9d332c151637418736437a99) [**breaking**] Remove templates ([#716](https://github.com/bosun-ai/swiftide/pull/716)) ````text Template / prompt interface got confusing and bloated. This removes `Template` fully, and changes Prompt such that it can either ref to a one-off, or to a template named compiled in the swiftide repository. 
```` **BREAKING CHANGE**: This removes `Template` from Swiftide and simplifies the whole setup significantly. The internal Swiftide Tera repository can still be extended like with Templates. Same behaviour with less code and abstractions. **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.8...0.23.0 ## [0.22.8](https://github.com/bosun-ai/swiftide/compare/v0.22.7...v0.22.8) - 2025-04-02 ### Bug fixes - [6b4dfca](https://github.com/bosun-ai/swiftide/commit/6b4dfca822f39b3700d60e6ea31b9b48ccd6d56f) Tool macros should work with latest darling version ([#712](https://github.com/bosun-ai/swiftide/pull/712)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.7...0.22.8 ## [0.22.7](https://github.com/bosun-ai/swiftide/compare/v0.22.6...v0.22.7) - 2025-03-30 ### Bug fixes - [b0001fb](https://github.com/bosun-ai/swiftide/commit/b0001fbb12cf6bb85fc4d5a8ef0968219e8c78db) *(duckdb)* Upsert is now opt in as it requires duckdb >= 1.2 ([#708](https://github.com/bosun-ai/swiftide/pull/708)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.6...0.22.7 ## [0.22.6](https://github.com/bosun-ai/swiftide/compare/v0.22.5...v0.22.6) - 2025-03-27 ### New features - [a05b3c8](https://github.com/bosun-ai/swiftide/commit/a05b3c8e7c4224c060215c34490b2ea7729592bf) *(macros)* Support optional values and make them even nicer to use ([#703](https://github.com/bosun-ai/swiftide/pull/703)) ### Bug fixes - [1866d5a](https://github.com/bosun-ai/swiftide/commit/1866d5a081f40123e607208d04403fb98f34c057) *(integrations)* Loosen up duckdb requirements even more and make it more flexible for version requirements ([#706](https://github.com/bosun-ai/swiftide/pull/706)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.5...0.22.6 # Changelog All notable changes to this project will be documented in this file. 
## [0.22.5](https://github.com/bosun-ai/swiftide/compare/v0.22.4...v0.22.5) - 2025-03-23 ### New features - [eb4e044](https://github.com/bosun-ai/swiftide/commit/eb4e0442293e17722743aa2b88d8dd7582dd9236) Estimate tokens for OpenAI like apis with tiktoken-rs ([#699](https://github.com/bosun-ai/swiftide/pull/699)) ### Miscellaneous - [345c57a](https://github.com/bosun-ai/swiftide/commit/345c57a663dd0d315a28f0927c5d598ba21d019d) Improve file loader logging ([#695](https://github.com/bosun-ai/swiftide/pull/695)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.4...0.22.5 ## [0.22.4](https://github.com/bosun-ai/swiftide/compare/v0.22.3...v0.22.4) - 2025-03-17 ### Bug fixes - [4ec00bb](https://github.com/bosun-ai/swiftide/commit/4ec00bb0fed214f27629f32569406bfa2c786dd7) *(integrations)* Add chrono/utc feature flag when using qdrant ([#684](https://github.com/bosun-ai/swiftide/pull/684)) ````text The Qdrant integration calls chrono::Utc::now(), which requires the now feature flag to be enabled in the chrono crate when using qdrant ```` - [0b204d9](https://github.com/bosun-ai/swiftide/commit/0b204d90a68978bb4b75516c537a56d665771c55) Ensure `groq`, `fastembed`, `test-utils` features compile individually ([#689](https://github.com/bosun-ai/swiftide/pull/689)) ### Miscellaneous - [bd4ef97](https://github.com/bosun-ai/swiftide/commit/bd4ef97f2b9207b5ac03d610b76bdb3440e3d5c0) Include filenames in errors in file io ([#694](https://github.com/bosun-ai/swiftide/pull/694)) ````text Uses fs-err crate to automatically include filenames in the error messages ```` - [9453e06](https://github.com/bosun-ai/swiftide/commit/9453e06d5338c99cec5f51b085739cc30a5f12be) Use std::sync::Mutex instead of tokio mutex ([#693](https://github.com/bosun-ai/swiftide/pull/693)) - [b3456e2](https://github.com/bosun-ai/swiftide/commit/b3456e25af99f661aff1779ae5f2d4da460f128c) Log qdrant setup messages at debug level ([#696](https://github.com/bosun-ai/swiftide/pull/696)) **Full 
Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.3...0.22.4 ## [0.22.3](https://github.com/bosun-ai/swiftide/compare/v0.22.2...v0.22.3) - 2025-03-13 ### Miscellaneous - [834fcd3](https://github.com/bosun-ai/swiftide/commit/834fcd3b2270904bcfe8998a7015de15626128a8) Update duckdb to 1.2.1 ([#680](https://github.com/bosun-ai/swiftide/pull/680)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.2...0.22.3 ## [0.22.2](https://github.com/bosun-ai/swiftide/compare/v0.22.1...v0.22.2) - 2025-03-11 ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies - [e1c097d](https://github.com/bosun-ai/swiftide/commit/e1c097da885374ec9320c1847a7dda7c5d9d41cb) Disable default features on all dependencies ([#675](https://github.com/bosun-ai/swiftide/pull/675)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.1...0.22.2 # Changelog All notable changes to this project will be documented in this file. ## [0.22.1](https://github.com/bosun-ai/swiftide/compare/v0.22.0...v0.22.1) - 2025-03-09 ### New features - [474d612](https://github.com/bosun-ai/swiftide/commit/474d6122596e71132e35fcb181302dfed7794561) *(integrations)* Add Duckdb support ([#578](https://github.com/bosun-ai/swiftide/pull/578)) ````text Adds support for Duckdb. Persist, Retrieve (Simple and Custom), and NodeCache are implemented. Metadata and full upsert are not. Once 1.2 has its issues fixed, it's easy to add. 
```` - [4cf417c](https://github.com/bosun-ai/swiftide/commit/4cf417c6a818fbec2641ad6576b4843412902bf6) *(treesitter)* C and C++ support for splitter only ([#663](https://github.com/bosun-ai/swiftide/pull/663)) ### Bug fixes - [590eaeb](https://github.com/bosun-ai/swiftide/commit/590eaeb3c6b5c14c56c925e038528326f88508a1) *(integrations)* Make openai parallel_tool_calls an Option ([#664](https://github.com/bosun-ai/swiftide/pull/664)) ````text o3-mini needs to omit parallel_tool_calls - so we need to allow for a None option to not include that field ```` ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies - [d864c7e](https://github.com/bosun-ai/swiftide/commit/d864c7e72ba01d3f187e4f6ab6ad3e6244ae0dc4) Downgrade duckdb to 1.1.1 and fix ci ([#671](https://github.com/bosun-ai/swiftide/pull/671)) - [9b685b3](https://github.com/bosun-ai/swiftide/commit/9b685b3281d9694c5faa58890a9aba32cba90f1c) Update and loosen deps ([#670](https://github.com/bosun-ai/swiftide/pull/670)) - [a64ca16](https://github.com/bosun-ai/swiftide/commit/a64ca1656b903a680cc70ac7b33ac40d9d356d4a) Tokio_stream features should include `time` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.22.0...0.22.1 ## [0.22.0](https://github.com/bosun-ai/swiftide/compare/v0.21.1...v0.22.0) - 2025-03-03 ### New features - [a754846](https://github.com/bosun-ai/swiftide/commit/a7548463367023d3e5a3a25dd84f06632b372f18) *(agents)* Implement Serialize and Deserialize for chat messages ````text Persist, retry later, evaluate it completions in a script, you name it. ```` - [0a592c6](https://github.com/bosun-ai/swiftide/commit/0a592c67621f3eba4ad6e0bfd5a539e19963cf17) *(indexing)* Add `iter()` for file loader ([#655](https://github.com/bosun-ai/swiftide/pull/655)) ````text Allows playing with the iterator outside of the stream. 
Relates to https://github.com/bosun-ai/kwaak/issues/337 ```` - [57116e9](https://github.com/bosun-ai/swiftide/commit/57116e9a30c722f47398be61838cc1ef4d0bbfac) Groq ChatCompletion ([#650](https://github.com/bosun-ai/swiftide/pull/650)) ````text Use the new generics to _just-make-it-work_. ```` - [4fd3259](https://github.com/bosun-ai/swiftide/commit/4fd325921555a14552e33b2481bc9dfcf0c313fc) Continue Agent on Tool Failure ([#628](https://github.com/bosun-ai/swiftide/pull/628)) ````text Ensure tool calls and responses are always balanced, even when the tool retry limit is reached https://github.com/bosun-ai/kwaak/issues/313 ```` ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.21.1...0.22.0 ## [0.21.1](https://github.com/bosun-ai/swiftide/compare/v0.21.0...v0.21.1) - 2025-02-28 ### Bug fixes - [f418c5e](https://github.com/bosun-ai/swiftide/commit/f418c5ee2f0d3ee87fb3715ec6b1d7ecc80bf714) *(ci)* Run just a single real rerank test to please the flaky gods - [e387e82](https://github.com/bosun-ai/swiftide/commit/e387e826200e1bc0a608e1f680537751cfc17969) *(lancedb)* Update Lancedb to 0.17 and pin Arrow to a lower version ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.21.0...0.21.1 ## [0.21.0](https://github.com/bosun-ai/swiftide/compare/v0.20.1...v0.21.0) - 2025-02-25 ### New features - [12a9873](https://github.com/bosun-ai/swiftide/commit/12a98736ab171c25d860000bb95b1e6e318758fb) *(agents)* Improve flexibility for tool generation (#641) ````text Previously ToolSpec and name in the `Tool` trait worked with static. With these changes, there is a lot more flexibility, allowing for i.e. run-time tool generation. 
```` ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.20.1...0.21.0 ## [0.20.1](https://github.com/bosun-ai/swiftide/compare/v0.20.0...v0.20.1) - 2025-02-21 ### Bug fixes - [0aa1248](https://github.com/bosun-ai/swiftide/commit/0aa124819d836f37d1fcaf88e6f88b5affb46cf9) *(indexing)* Handle invalid utf-8 in fileloader lossy (#632) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.20.0...0.20.1 # Changelog All notable changes to this project will be documented in this file. ## [0.20.0](https://github.com/bosun-ai/swiftide/compare/v0.19.0...v0.20.0) - 2025-02-18 ### New features - [5d85d14](https://github.com/bosun-ai/swiftide/commit/5d85d142339d24c793bd89a907652bede0d1c94d) *(agents)* Add support for numbers, arrays and booleans in tool args (#562) ````text Add support for numbers, arrays and boolean types in the `#[swiftide_macros::tool]` attribute macro. For enum and object a custom implementation is now properly supported as well, but not via the macro. For now, tools using Derive also still need a custom implementation. 
```` - [b09afed](https://github.com/bosun-ai/swiftide/commit/b09afed72d463d8b59ffa2b325eb6a747c88c87f) *(query)* Add support for reranking with `Fastembed` and multi-document retrieval (#508) ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.19.0...0.20.0 ## [0.19.0](https://github.com/bosun-ai/swiftide/compare/v0.18.2...v0.19.0) - 2025-02-13 ### New features - [fa5112c](https://github.com/bosun-ai/swiftide/commit/fa5112c9224fdf5984d26db669f04dedc8ebb561) *(agents)* By default retry failed tools with LLM up to 3 times (#609) ````text Specifically meant for LLMs sending invalid JSON, these tool calls are now retried by feeding back the error into the LLM up to a limit (default 3). ```` - [14f4778](https://github.com/bosun-ai/swiftide/commit/14f47780b4294be3a9fa3670aa18a952ad7e9d6e) *(integrations)* Parallel tool calling in OpenAI is now configurable (#611) ````text Adds support for reasoning models in agents and for chat completions. 
```` - [37a1a2c](https://github.com/bosun-ai/swiftide/commit/37a1a2c7bfd152db56ed929e0ea1ab99080e640d) *(integrations)* Add system prompts as `system` instead of message in Anthropic requests ### Bug fixes - [ab27c75](https://github.com/bosun-ai/swiftide/commit/ab27c75b8f4a971cb61e88b26d94231afd35c871) *(agents)* Add back anyhow catch all for failed tools - [2388f18](https://github.com/bosun-ai/swiftide/commit/2388f187966d996ede4ff42c71521238b63d129c) *(agents)* Use name/arg hash on tool retries (#612) - [da55664](https://github.com/bosun-ai/swiftide/commit/da5566473e3f8874fce427ceb48a15d002737d07) *(integrations)* Scraper should stop when finished (#614) ### Miscellaneous - [990a8ea](https://github.com/bosun-ai/swiftide/commit/990a8eaeffdbd447bb05a0b01aa65a39a7c9cacf) *(deps)* Update tree-sitter (#616) - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.18.2...0.19.0 ## [0.18.2](https://github.com/bosun-ai/swiftide/compare/v0.18.1...v0.18.2) - 2025-02-11 ### New features - [50ffa15](https://github.com/bosun-ai/swiftide/commit/50ffa156e28bb085a61a376bab71c135bc09622f) Anthropic support for prompts and agents (#602) ### Bug fixes - [8cf70e0](https://github.com/bosun-ai/swiftide/commit/8cf70e08787d1376ba20001cc9346767d8bd84ef) *(integrations)* Ensure anthropic tool call format is consistent with specs ### Miscellaneous - [98176c6](https://github.com/bosun-ai/swiftide/commit/98176c603b61e3971ca5583f9f4346eb5b962d51) Clippy **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.18.1...0.18.2 # Changelog All notable changes to this project will be documented in this file. 
## [0.18.1](https://github.com/bosun-ai/swiftide/compare/v0.18.0...v0.18.1) - 2025-02-09 ### New features - [78bf0e0](https://github.com/bosun-ai/swiftide/commit/78bf0e004049c852d4e32c0cd67725675b1250f9) *(agents)* Add optional limit for agent iterations (#599) - [592e5a2](https://github.com/bosun-ai/swiftide/commit/592e5a2ca4b0f09ba6a9b20cef105539cb7a7909) *(integrations)* Support Azure openai via generics (#596) - [c8f2eed](https://github.com/bosun-ai/swiftide/commit/c8f2eed9964341ac2dad611fc730dc234436430a) *(tree-sitter)* Add solidity support (#597) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.18.0...0.18.1 # Changelog All notable changes to this project will be documented in this file. ## [0.18.0](https://github.com/bosun-ai/swiftide/compare/v0.17.5...v0.18.0) - 2025-02-01 ### New features - [de46656](https://github.com/bosun-ai/swiftide/commit/de46656f80c5cf68cc192d21b5f34eb3e0667a14) *(agents)* Add `on_start` hook (#586) - [c551f1b](https://github.com/bosun-ai/swiftide/commit/c551f1becfd1750ce480a00221a34908db61e42f) *(integrations)* OpenRouter support (#589) ````text Adds OpenRouter support. OpenRouter allows you to use any LLM via their own api (with a minor upsell). ```` ### Bug fixes - [3ea5839](https://github.com/bosun-ai/swiftide/commit/3ea583971c0d2cc5ef0594eaf764ea149bacd1d8) *(redb)* Disable per-node tracing ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.lock dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.17.5...0.18.0 ## [0.17.5](https://github.com/bosun-ai/swiftide/compare/v0.17.4...v0.17.5) - 2025-01-27 ### New features - [825a52e](https://github.com/bosun-ai/swiftide/commit/825a52e70a74e4621d370485346a78d61bf5d7a9) *(agents)* Tool description now also accepts paths (i.e. 
a const) (#580) ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.lock dependencies - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.17.4...0.17.5 ## [0.17.4](https://github.com/bosun-ai/swiftide/compare/v0.17.3...v0.17.4) - 2025-01-24 ### Bug fixes - [0d9e250](https://github.com/bosun-ai/swiftide/commit/0d9e250e2512fe9c66d5dfd2ac688dcd56bd07e9) *(tracing)* Use `or_current()` to prevent orphaned tracing spans (#573) ````text When a span is emitted that would be selected by the subscriber, but we instrument its closure with a span that would not be selected by the subscriber, the span would be emitted as an orphan (with a new `trace_id`) making them hard to find and cluttering dashboards. This situation is also documented here: https://docs.rs/tracing/latest/tracing/struct.Span.html#method.or_current ```` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.17.3...0.17.4 ## [0.17.3](https://github.com/bosun-ai/swiftide/compare/v0.17.2...v0.17.3) - 2025-01-24 ### New features - [8e22442](https://github.com/bosun-ai/swiftide/commit/8e2244241f16fff77591cf04f40725ad0b05ca81) *(integrations)* Support Qdrant 1.13 (#571) ### Bug fixes - [c5408a9](https://github.com/bosun-ai/swiftide/commit/c5408a96fbed6207022eb493da8d2cbb0fea7ca6) *(agents)* Io::Error should always be a NonZeroExit error for tool executors (#570) ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.lock dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.17.2...0.17.3 ## [0.17.2](https://github.com/bosun-ai/swiftide/compare/v0.17.1...v0.17.2) - 2025-01-21 ### Bug fixes - [47db5ab](https://github.com/bosun-ai/swiftide/commit/47db5ab138384a6c235a90024470e9ab96751cc8) 
*(agents)* Redrive uses the correct pointer and works as intended **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.17.1...0.17.2 ## [0.17.1](https://github.com/bosun-ai/swiftide/compare/v0.17.0...v0.17.1) - 2025-01-20 ### New features - [e4e4468](https://github.com/bosun-ai/swiftide/commit/e4e44681b65b07b5f1e987ce468bdcda61eb30da) *(agents)* Implement AgentContext for smart dyn pointers - [70181d9](https://github.com/bosun-ai/swiftide/commit/70181d9642aa2c0a351b9f42be1a8cdbd83c9075) *(agents)* Add pub accessor for agent context (#558) - [274d9d4](https://github.com/bosun-ai/swiftide/commit/274d9d46f39ac2e28361c4881c6f8f7e20dd8753) *(agents)* Preprocess tool calls to fix common, fixable errors (#560) ````text OpenAI has a tendency to sometimes send double keys. With this, Swiftide will now take the first key and ignore any duplicates after that. Sets the stage for any future preprocessing before it gets strictly parsed by serde. ```` - [0f0f491](https://github.com/bosun-ai/swiftide/commit/0f0f491b2621ad82389a57bdb521fcf4021b7d7a) *(integrations)* Add Dashscope support (#543) ````text --------- ```` ### Bug fixes - [b2b15ac](https://github.com/bosun-ai/swiftide/commit/b2b15ac073e4f6b035239791a056fbdf6f6e704e) *(openai)* Enable strict mode for tool calls (#561) ````text Ensures openai sticks much better to the schema and avoids accidental mistakes. ```` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.17.0...0.17.1 ## [0.17.0](https://github.com/bosun-ai/swiftide/compare/v0.16.4...v0.17.0) - 2025-01-16 ### New features - [835c35e](https://github.com/bosun-ai/swiftide/commit/835c35e7d74811daa90f7ca747054d1919633058) *(agents)* Redrive completions manually on failure (#551) ````text Sometimes LLMs fail a completion without deterministic errors, or the use case where you just want to retry. `redrive` can now be called on a context, popping any new messages (if any), and making the messages available again to the agent. 
```` - [f83f3f0](https://github.com/bosun-ai/swiftide/commit/f83f3f03bbf6a9591b54521dde91bf1a5ed19c5c) *(agents)* Implement ToolExecutor for common dyn pointers (#549) - [7f85735](https://github.com/bosun-ai/swiftide/commit/7f857358e46e825494ba927dffb33c3afa0d762e) *(query)* Add custom lancedb query generation for lancedb search (#518) - [ce4e34b](https://github.com/bosun-ai/swiftide/commit/ce4e34be42ce1a0ab69770d03695bd67f99a8739) *(tree-sitter)* Add golang support (#552) ````text Seems someone conveniently forgot to add Golang support for the splitter. ```` ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.lock dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.16.4...0.17.0 ## [0.16.4](https://github.com/bosun-ai/swiftide/compare/v0.16.3...v0.16.4) - 2025-01-12 ### New features - [c919484](https://github.com/bosun-ai/swiftide/commit/c9194845faa12b8a0fcecdd65f8ec9d3d221ba08) Ollama via async-openai with chatcompletion support (#545) ````text Adds support for chatcompletions (agents) for ollama. SimplePrompt and embeddings now use async-openai underneath. Copy pasted as I expect some differences in the future. ```` ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.16.3...0.16.4 ## [0.16.3](https://github.com/bosun-ai/swiftide/compare/v0.16.2...v0.16.3) - 2025-01-10 ### New features - [b66bd79](https://github.com/bosun-ai/swiftide/commit/b66bd79070772d7e1bfe10a22531ccfd6501fc2a) *(fastembed)* Add support for jina v2 code (#541) ````text Add support for jina v2 code in fastembed. 
```` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.16.2...0.16.3 ## [0.16.2](https://github.com/bosun-ai/swiftide/compare/v0.16.1...v0.16.2) - 2025-01-08 ### Bug fixes - [2226755](https://github.com/bosun-ai/swiftide/commit/2226755f367d9006870a2dea2063655a7901d427) Explicit cast on tools to Box to make analyzer happy (#536) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.16.1...0.16.2 ## [0.16.1](https://github.com/bosun-ai/swiftide/compare/v0.16.0...v0.16.1) - 2025-01-06 ### Bug fixes - [d198bb0](https://github.com/bosun-ai/swiftide/commit/d198bb0807f5d5b12a51bc76721cc945be8e65b9) *(prompts)* Skip rendering prompts if no context and forward as is (#530) ````text Fixes an issue if strings suddenly include jinja style values by mistake. Bonus performance boost. ```` - [4e8d59f](https://github.com/bosun-ai/swiftide/commit/4e8d59fbc0fbe72dd0f8d6a95e6e335280eb88e3) *(redb)* Log errors and return uncached instead of panicing (#531) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.16.0...0.16.1 ## [0.16.0](https://github.com/bosun-ai/swiftide/compare/v0.15.0...v0.16.0) - 2025-01-02 ### New features - [52e341e](https://github.com/bosun-ai/swiftide/commit/52e341ee9777d04f9fb07054980ba087c55c033e) *(lancedb)* Public method for opening table (#514) - [3254bd3](https://github.com/bosun-ai/swiftide/commit/3254bd34d0eeb038c8aa6ea56ac2940b3ca81960) *(query)* Generic templates with document rendering (#520) ````text Reworks `PromptTemplate` to a more generic `Template`, such that they can also be used elsewhere. This deprecates `PromptTemplate`. As an example, an optional `Template` in the `Simple` answer transformer, which can be used to customize the output of retrieved documents. This has excellent synergy with the metadata changes in #504. 
```` - [235780b](https://github.com/bosun-ai/swiftide/commit/235780b941a0805b69541f0f4c55c3404091baa8) *(query)* Documents as first class citizens (#504) ````text For simple RAG, just adding the content of a retrieved document might be enough. However, in more complex use cases, you might want to add metadata as well, as is or for conditional formatting. For instance, when dealing with large amounts of chunked code, providing the path goes a long way. If generated metadata is good enough, could be useful as well. With this retrieved Documents are treated as first class citizens, including any metadata as well. Additionally, this also paves the way for multi retrieval (and multi modal). ```` - [584695e](https://github.com/bosun-ai/swiftide/commit/584695e4841a3c9341e521b81e9f254270b3416e) *(query)* Add custom SQL query generation for pgvector search (#478) ````text Adds support for custom retrieval queries with the sqlx query builder for PGVector. Puts down the fundamentals for custom query building for any retriever. --------- ```` - [b55bf0b](https://github.com/bosun-ai/swiftide/commit/b55bf0b318042459a6983cf725078c4da662618b) *(redb)* Public database and table definition (#510) - [176378f](https://github.com/bosun-ai/swiftide/commit/176378f846ddecc3ddba74f6b423338b793f29b4) Implement traits for all Arc dynamic dispatch (#513) ````text If you use i.e. a `Persist` or a `NodeCache` outside swiftide as well, and you already have it Arc'ed, now it just works. 
```` - [dc9881e](https://github.com/bosun-ai/swiftide/commit/dc9881e48da7fb5dc744ef33b1c356b4152d00d3) Allow opt out of pipeline debug truncation ### Bug fixes - [2831101](https://github.com/bosun-ai/swiftide/commit/2831101daa2928b5507116d9eb907d98fb77bf50) *(lancedb)* Metadata should be nullable in lancedb (#515) - [c35df55](https://github.com/bosun-ai/swiftide/commit/c35df5525d4d88cfb9ada89a060e1ab512b471af) *(macros)* Explicit box dyn cast fixing Rust Analyzer troubles (#523) ### Miscellaneous - [1bbbb0e](https://github.com/bosun-ai/swiftide/commit/1bbbb0e548cafa527c34856bd9ac6f76aca2ab5f) Clippy **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.15.0...0.16.0 ## [0.15.0](https://github.com/bosun-ai/swiftide/compare/v0.14.4...v0.15.0) - 2024-12-23 ### New features - [a1b9a2d](https://github.com/bosun-ai/swiftide/commit/a1b9a2d37715420d3e2cc80d731e3713a22c7c50) *(query)* Ensure concrete names for transformations are used when debugging (#496) - [7779c44](https://github.com/bosun-ai/swiftide/commit/7779c44de3581ac865ac808637c473525d27cabb) *(query)* Ensure query pipeline consistently debug logs in all other stages too - [55dde88](https://github.com/bosun-ai/swiftide/commit/55dde88df888b60a7ccae5a68ba03d20bc1f57df) *(query)* Debug full retrieved documents when debug mode is enabled (#495) - [66031ba](https://github.com/bosun-ai/swiftide/commit/66031ba27b946add0533775423d468abb3187604) *(query)* Log query pipeline answer on debug (#497) ### Miscellaneous - [d255772](https://github.com/bosun-ai/swiftide/commit/d255772cc933c839e3aaaffccd343acf75dcb251) *(agents)* Rename `CommandError::FailedWithOutput` to `CommandError::NonZeroExit` (#484) ````text Better describes what is going on. I.e. `rg` exits with 1 if nothing is found, tests generally do the same if they fail. 
```` - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.14.4...0.15.0 ## [0.14.4](https://github.com/bosun-ai/swiftide/compare/v0.14.3...v0.14.4) - 2024-12-11 ### New features - [7211559](https://github.com/bosun-ai/swiftide/commit/7211559936d8b5e16a3b42f9c90b42a39426be8a) *(agents)* **EXPERIMENTAL** Agents in Swiftide (#463) ````text Agents are coming to Swiftide! We are still ironing out all the kinks, while we make it ready for a proper release. You can already experiment with agents, see the rustdocs for documentation, and an example in `/examples`, and feel free to contact us via github or discord. Better documentation, examples, and tutorials are coming soon. Run completions in a loop, define tools with two handy macros, customize the agent by hooking in on lifecycle events, and much more. Besides documentation, expect a big release for what we build this for soon! 🎉 ```` - [3751f49](https://github.com/bosun-ai/swiftide/commit/3751f49201c71398144a8913a4443f452534def2) *(query)* Add support for single embedding retrieval with PGVector (#406) ### Miscellaneous - [5ce4d21](https://github.com/bosun-ai/swiftide/commit/5ce4d21725ff9b0bb7f9da8fe026075fde9fc9a5) Clippy and deps fixes for 1.83 (#467) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.14.3...0.14.4 ## [0.14.3](https://github.com/bosun-ai/swiftide/compare/v0.14.2...v0.14.3) - 2024-11-20 ### New features - [1774b84](https://github.com/bosun-ai/swiftide/commit/1774b84f00a83fe69af4a2b6a6daf397d4d9b32d) *(integrations)* Add PGVector support for indexing ([#392](https://github.com/bosun-ai/swiftide/pull/392)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.14.2...0.14.3 ## [0.14.2](https://github.com/bosun-ai/swiftide/compare/v0.14.1...v0.14.2) - 2024-11-08 ### Bug fixes - 
[3924322](https://github.com/bosun-ai/swiftide/commit/39243224d739a76cf2b60204fc67819055b7bc6f) *(querying)* Query pipeline is now properly send and sync when possible ([#425](https://github.com/bosun-ai/swiftide/pull/425)) ### Miscellaneous - [52198f7](https://github.com/bosun-ai/swiftide/commit/52198f7fe76376a42c1fec8945bda4bf3e6971d4) Improve local dev build speed ([#434](https://github.com/bosun-ai/swiftide/pull/434)) ````text - **Tokio on rt-multi-thread only** - **Remove manual checks from lancedb integration test** - **Ensure all deps in workspace manifest** - **Remove unused deps** - **Remove examples and benchmarks from default members** ```` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.14.1...0.14.2 ## [0.14.1](https://github.com/bosun-ai/swiftide/compare/v0.14.0...v0.14.1) - 2024-10-27 ### Bug fixes - [5bbcd55](https://github.com/bosun-ai/swiftide/commit/5bbcd55de65d73d7908e91c96f120928edb6b388) Revert 0.14 release as mistralrs is unpublished ([#417](https://github.com/bosun-ai/swiftide/pull/417)) ````text Revert the 0.14 release as `mistralrs` is unpublished and unfortunately cannot be released. 
```` ### Miscellaneous - [07c2661](https://github.com/bosun-ai/swiftide/commit/07c2661b7a7cdf75cdba12fab0ca91866793f727) Re-release 0.14 without mistralrs ([#419](https://github.com/bosun-ai/swiftide/pull/419)) ````text - **Revert "fix: Revert 0.14 release as mistralrs is unpublished ([#417](https://github.com/bosun-ai/swiftide/pull/417))"** - **Fix changelog** ```` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.14.0...0.14.1 ## [0.14.0](https://github.com/bosun-ai/swiftide/compare/v0.13.4...v0.14.0) - 2024-10-27 ### Bug fixes - [551a9cb](https://github.com/bosun-ai/swiftide/commit/551a9cb769293e42e15bae5dca3ab677be0ee8ea) *(indexing)* [**breaking**] Node ID no longer memoized ([#414](https://github.com/bosun-ai/swiftide/pull/414)) ````text As @shamb0 pointed out in [#392](https://github.com/bosun-ai/swiftide/pull/392), there is a potential issue where Node ids get cached before chunking or other transformations, breaking upserts and potentially resulting in data loss. ```` **BREAKING CHANGE**: This PR reworks Nodes with a builder API and a private id. Hence, manually creating nodes no longer works. In the future, all the fields are likely to follow the same pattern, so that we can decouple the inner fields from the Node's implementation. 
- [c091ffa](https://github.com/bosun-ai/swiftide/commit/c091ffa6be792b0bd7bb03d604e26e40b2adfda8) *(indexing)* Use atomics for key generation in memory storage ([#415](https://github.com/bosun-ai/swiftide/pull/415)) ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.4...0.14.0 ## [0.13.4](https://github.com/bosun-ai/swiftide/compare/v0.13.3...v0.13.4) - 2024-10-21 ### Bug fixes - [47455fb](https://github.com/bosun-ai/swiftide/commit/47455fb04197a4b51142e2fb4c980e42ac54d11e) *(indexing)* Visibility of ChunkMarkdown builder should be public - [2b3b401](https://github.com/bosun-ai/swiftide/commit/2b3b401dcddb2cb32214850b9b4dbb0481943d38) *(indexing)* Improve splitters consistency and provide defaults ([#403](https://github.com/bosun-ai/swiftide/pull/403)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.3...0.13.4 # Changelog All notable changes to this project will be documented in this file. ## [0.13.3](https://github.com/bosun-ai/swiftide/compare/v0.13.2...v0.13.3) - 2024-10-11 ### Bug fixes - [2647f16](https://github.com/bosun-ai/swiftide/commit/2647f16dc164eb5230d8f7c6d71e31663000cb0d) *(deps)* Update rust crate text-splitter to 0.17 ([#366](https://github.com/bosun-ai/swiftide/pull/366)) - [d74d85b](https://github.com/bosun-ai/swiftide/commit/d74d85be3bd98706349eff373c16443b9c45c4f0) *(indexing)* Add missing `Embed::batch_size` implementation ([#378](https://github.com/bosun-ai/swiftide/pull/378)) - [95f78d3](https://github.com/bosun-ai/swiftide/commit/95f78d3412951c099df33149c57817338a76553d) *(tree-sitter)* Compile regex only once ([#371](https://github.com/bosun-ai/swiftide/pull/371)) ````text Regex compilation is not cheap, use a static with a oncelock instead. 
```` ### Miscellaneous - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.2...0.13.3 ## [0.13.2](https://github.com/bosun-ai/swiftide/compare/v0.13.1...v0.13.2) - 2024-10-05 ### New features - [4b13aa7](https://github.com/bosun-ai/swiftide/commit/4b13aa7d76dfc7270870682e2f757f066a99ba4e) *(core)* Add support for cloning all trait objects ([#355](https://github.com/bosun-ai/swiftide/pull/355)) ````text For instance, if you have a `Box`, you can now clone into an owned copy and more effectively use the available generics. This also works for borrowed trait objects. ```` - [ed3da52](https://github.com/bosun-ai/swiftide/commit/ed3da52cf89b2384ec6f07c610c591b3eda2fa28) *(indexing)* Support Redb as embedable nodecache ([#346](https://github.com/bosun-ai/swiftide/pull/346)) ````text Adds support for Redb as an embeddable node cache, allowing full local app development without needing external services. 
```` ### Bug fixes - [06f8336](https://github.com/bosun-ai/swiftide/commit/06f83361c52010a451e8b775ce9c5d67057edbc5) *(indexing)* Ensure `name()` returns concrete name on trait objects ([#351](https://github.com/bosun-ai/swiftide/pull/351)) ### Miscellaneous - [8237c28](https://github.com/bosun-ai/swiftide/commit/8237c2890df681c48117188e80cbad914b91e0fd) *(core)* Mock traits for testing should not have their docs hidden - [0000000](https://github.com/bosun-ai/swiftide/commit/0000000) Update Cargo.toml dependencies **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.1...0.13.2 ## [0.13.1](https://github.com/bosun-ai/swiftide/compare/v0.13.0...v0.13.1) - 2024-10-02 ### Bug fixes - [e6d9ec2](https://github.com/bosun-ai/swiftide/commit/e6d9ec2fe034c9d36fd730c969555c459606d42f) *(lancedb)* Should not error if table exists ([#349](https://github.com/bosun-ai/swiftide/pull/349)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.13.0...0.13.1 ## [0.13.0](https://github.com/bosun-ai/swiftide/compare/v0.12.3...v0.13.0) - 2024-09-26 ### New features - [7d8a57f](https://github.com/bosun-ai/swiftide/commit/7d8a57f54b2c73267dfaa3b3a32079b11d9b32bc) *(indexing)* [**breaking**] Removed duplication of batch_size ([#336](https://github.com/bosun-ai/swiftide/pull/336)) **BREAKING CHANGE**: The batch size of batch transformers when indexing is now configured on the batch transformer. If no batch size or default is configured, a configurable default is used from the pipeline. The default batch size is 256. 
- [fd110c8](https://github.com/bosun-ai/swiftide/commit/fd110c8efeb3af538d4e51d033b6df02e90e05d9) *(tree-sitter)* Add support for Java 22 ([#309](https://github.com/bosun-ai/swiftide/pull/309)) ### Bug fixes - [23b96e0](https://github.com/bosun-ai/swiftide/commit/23b96e08b4e0f10f5faea0b193b404c9cd03f47f) *(tree-sitter)* [**breaking**] SupportedLanguages are now non-exhaustive ([#331](https://github.com/bosun-ai/swiftide/pull/331)) **BREAKING CHANGE**: SupportedLanguages are now non-exhaustive. This means that matching on SupportedLanguages will now require a catch-all arm. This change was made to allow for future languages to be added without breaking changes. ### Miscellaneous - [923a8f0](https://github.com/bosun-ai/swiftide/commit/923a8f0663e7d2b7138f54069f7a74c3cf6663ed) *(fastembed,qdrant)* Better batching defaults ([#334](https://github.com/bosun-ai/swiftide/pull/334)) ```text Qdrant and FastEmbed now have a default batch size, removing the need to set it manually. The default batch size is 50 and 256 respectively. ``` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.12.3...0.13.0 ## [0.12.3](https://github.com/bosun-ai/swiftide/releases/tag/0.12.3) - 2024-09-23 ### New features - [da5df22](https://github.com/bosun-ai/swiftide/commit/da5df2230da81e9fe1e6ab74150511cbe1e3d769) *(tree-sitter)* Implement Serialize and Deserialize for SupportedLanguages ([#314](https://github.com/bosun-ai/swiftide/pull/314)) ### Bug fixes - [a756148](https://github.com/bosun-ai/swiftide/commit/a756148f85faa15b1a79db8ec8106f0e15e4d6a2) *(tree-sitter)* Fix javascript and improve tests ([#313](https://github.com/bosun-ai/swiftide/pull/313)) ````text As learned from [#309](https://github.com/bosun-ai/swiftide/pull/309), test coverage for the refs defs transformer was not great. There _are_ more tests in code_tree. Turns out, with the latest treesitter update, javascript broke as it was the only language not covered at all. 
```` ### Miscellaneous - [e8e9d80](https://github.com/bosun-ai/swiftide/commit/e8e9d80f2b4fbfe7ca2818dc542ca0a907a17da5) *(docs)* Add documentation to query module ([#276](https://github.com/bosun-ai/swiftide/pull/276)) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/0.12.2...0.12.3 ## [v0.12.2](https://github.com/bosun-ai/swiftide/releases/tag/v0.12.2) - 2024-09-20 ### Docs - [d84814e](https://github.com/bosun-ai/swiftide/commit/d84814eef1bf12e485053fb69fb658d963100789) Fix broken documentation links and other cargo doc warnings (#304) by @tinco ````text Running `cargo doc --all-features` resulted in a lot of warnings. ```` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.12.1...v0.12.2 ## [v0.12.1](https://github.com/bosun-ai/swiftide/releases/tag/v0.12.1) - 2024-09-16 ### New features - [ec227d2](https://github.com/bosun-ai/swiftide/commit/ec227d25b987b7fd63ab1b3862ef19b14632bd04) *(indexing,query)* Add concise info log with transformation name by @timonv - [01cf579](https://github.com/bosun-ai/swiftide/commit/01cf579922a877bb78e0de20114ade501e5a63db) *(query)* Add query_mut for reusable query pipelines by @timonv - [081a248](https://github.com/bosun-ai/swiftide/commit/081a248e67292c1800837315ec53583be5e0cb82) *(query)* Improve query performance similar to indexing in 0.12 by @timonv - [8029926](https://github.com/bosun-ai/swiftide/commit/80299269054eb440e55a42667a7bcc9ba6514a7b) *(query,indexing)* Add duration in log output on pipeline completion by @timonv ### Bug fixes - [39b6ecb](https://github.com/bosun-ai/swiftide/commit/39b6ecb6175e5233b129f94876f95182b8bfcdc3) *(core)* Truncate long strings safely when printing debug logs by @timonv - [8b8ceb9](https://github.com/bosun-ai/swiftide/commit/8b8ceb9266827857859481c1fc4a0f0c40805e33) *(deps)* Update redis by @timonv - [16e9c74](https://github.com/bosun-ai/swiftide/commit/16e9c7455829100b9ae82305e5a1d2568264af9f) *(openai)* Reduce debug verbosity by @timonv - 
[6914d60](https://github.com/bosun-ai/swiftide/commit/6914d607717294467cddffa867c3d25038243fc1) *(qdrant)* Reduce debug verbosity when storing nodes by @timonv - [3d13889](https://github.com/bosun-ai/swiftide/commit/3d1388973b5e2a135256ae288d47dbde0399487f) *(query)* Reduce and improve debugging verbosity by @timonv - [133cf1d](https://github.com/bosun-ai/swiftide/commit/133cf1d0be09049ca3e90b45675a965bb2464cb2) *(query)* Remove verbose debug and skip self in instrumentation by @timonv - [ce17981](https://github.com/bosun-ai/swiftide/commit/ce179819ab75460453236723c7f9a89fd61fb99a) Clippy by @timonv - [a871c61](https://github.com/bosun-ai/swiftide/commit/a871c61ad52ed181d6f9cb6a66ed07bccaadee08) Fmt by @timonv ### Miscellaneous - [d62b047](https://github.com/bosun-ai/swiftide/commit/d62b0478872e460956607f52b72470b76eb32d91) *(ci)* Update testcontainer images and fix tests by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.12.0...v0.12.1 ## [v0.12.0](https://github.com/bosun-ai/swiftide/releases/tag/v0.12.0) - 2024-09-13 ### New features - [e902cb7](https://github.com/bosun-ai/swiftide/commit/e902cb7487221d3e88f13d88532da081e6ef8611) *(query)* Add support for filters in SimilaritySingleEmbedding (#298) by @timonv ````text Adds support for filters for Qdrant and Lancedb in SimilaritySingleEmbedding. Also fixes several small bugs and brings improved tests. ```` - [f158960](https://github.com/bosun-ai/swiftide/commit/f1589604d1e0cb42a07d5a48080e3d7ecb90ee38) Major performance improvements (#291) by @timonv ````text Futures that do not yield were not run in parallel properly. With this futures are spawned on a tokio worker thread by default. When embedding (fastembed) and storing a 85k row dataset, there's a ~1.35x performance improvement: image ~~Need to do one more test with IO bound futures as well. Pretty huge, not that it was slow.~~ With IO bound openai it's 1.5x. 
```` ### Bug fixes - [f8314cc](https://github.com/bosun-ai/swiftide/commit/f8314ccdbe16ad7e6691899dd01f81a61b20180f) *(indexing)* Limit logged chunk to max 100 chars (#292) by @timonv - [f95f806](https://github.com/bosun-ai/swiftide/commit/f95f806a0701b14a3cad5da307c27c01325a264d) *(indexing)* Debugging nodes should respect utf8 char boundaries by @timonv - [8595553](https://github.com/bosun-ai/swiftide/commit/859555334d7e4129215b9f084d9f9840fac5ce36) Implement into_stream_boxed for all loaders by @timonv - [9464ca1](https://github.com/bosun-ai/swiftide/commit/9464ca123f08d8dfba3f1bfabb57e9af97018534) Bad embed error propagation (#293) by @timonv ````text - **fix(indexing): Limit logged chunk to max 100 chars** - **fix: Embed transformers must correctly propagate errors** ```` ### Miscellaneous - [45d8a57](https://github.com/bosun-ai/swiftide/commit/45d8a57d1afb4f16ad76b15236308d753cf45743) *(ci)* Use llm-cov preview via nightly and improve test coverage (#289) by @timonv ````text Fix test coverage in CI. Simplified the trait bounds on the query pipeline for now to make it all work and fit together, and added more tests to assert boxed versions of trait objects work in tests. 
```` - [408f30a](https://github.com/bosun-ai/swiftide/commit/408f30ad8d007394ba971b314d399fcd378ffb61) *(deps)* Update testcontainers (#295) by @timonv - [37c4bd9](https://github.com/bosun-ai/swiftide/commit/37c4bd9f9ac97646adb2c4b99b8f7bf0bee4c794) *(deps)* Update treesitter (#296) by @timonv - [8d9e954](https://github.com/bosun-ai/swiftide/commit/8d9e9548ccc1b39e302ee42dd5058f50df13270f) Cargo update by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.11.1...v0.12.0 ## [v0.11.1](https://github.com/bosun-ai/swiftide/releases/tag/v0.11.1) - 2024-09-10 ### New features - [3c9491b](https://github.com/bosun-ai/swiftide/commit/3c9491b8e1ce31a030eaac53f56890629a087f70) Implement traits T for Box for indexing and query traits (#285) by @timonv ````text When working with trait objects, some pipeline steps now allow for Box as well. ```` ### Bug fixes - [dfa546b](https://github.com/bosun-ai/swiftide/commit/dfa546b310e71a7cb78a927cc8f0ee4e2046a592) Add missing parquet feature flag by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.11.0...v0.11.1 ## [v0.11.0](https://github.com/bosun-ai/swiftide/releases/tag/v0.11.0) - 2024-09-08 ### New features - [bdf17ad](https://github.com/bosun-ai/swiftide/commit/bdf17adf5d3addc84aaf45ad893b816cb46431e3) *(indexing)* Parquet loader (#279) by @timonv ````text Ingest and index data from parquet files. ```` - [a98dbcb](https://github.com/bosun-ai/swiftide/commit/a98dbcb455d33f0537cea4d3614da95f1a4b6554) *(integrations)* Add ollama embeddings support (#278) by @ephraimkunz ````text Update to the most recent ollama-rs, which exposes the batch embedding API Ollama exposes (https://github.com/pepperoni21/ollama-rs/pull/61). This allows the Ollama struct in Swiftide to implement `EmbeddingModel`. Use the same pattern that the OpenAI struct uses to manage separate embedding and prompt models.
--------- ```` ### Miscellaneous - [873795b](https://github.com/bosun-ai/swiftide/commit/873795b31b3facb0cf5efa724cb391f7bf387fb0) *(ci)* Re-enable coverage via Coverals with tarpaulin (#280) by @timonv - [465de7f](https://github.com/bosun-ai/swiftide/commit/465de7fc952d66f4cd15002ef39aab0e7ec3ac26) Update CHANGELOG.md with breaking change by @timonv ### New Contributors * @ephraimkunz made their first contribution in [#278](https://github.com/bosun-ai/swiftide/pull/278) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.10.0...v0.11.0 ## [v0.10.0](https://github.com/bosun-ai/swiftide/releases/tag/v0.10.0) - 2024-09-06 ### Bug fixes - [5a724df](https://github.com/bosun-ai/swiftide/commit/5a724df895d35cfa606721d611afd073a23191de) [**breaking**] Rust 1.81 support (#275) by @timonv ````text Fixing id generation properly as per #272, will be merged in together. - **Clippy** - **fix(qdrant)!: Default hasher changed in Rust 1.81** ```` **BREAKING CHANGE**: Rust 1.81 support (#275) ### Docs - [3711f6f](https://github.com/bosun-ai/swiftide/commit/3711f6fb2b51e97e4606b744cc963c04b44b6963) *(readme)* Fix date (#273) by @dzvon ````text I suppose this should be 09-02. ```` ### New Contributors * @dzvon made their first contribution in [#273](https://github.com/bosun-ai/swiftide/pull/273) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.9.2...v0.10.0 ## [v0.9.2](https://github.com/bosun-ai/swiftide/releases/tag/v0.9.2) - 2024-09-04 ### New features - [84e9bae](https://github.com/bosun-ai/swiftide/commit/84e9baefb366f0a949ae7dcbdd8f97931da0b4be) *(indexing)* Add chunker for text with text_splitter (#270) by @timonv - [387fbf2](https://github.com/bosun-ai/swiftide/commit/387fbf29c2bce06284548f9af146bb3969562761) *(query)* Hybrid search for qdrant in query pipeline (#260) by @timonv ````text Implement hybrid search for qdrant with their new Fusion search. 
Example in /examples includes an indexing and query pipeline, included the example answer as well. ```` ### Docs - [064c7e1](https://github.com/bosun-ai/swiftide/commit/064c7e157775a7aaf9628a39f941be35ce0be99a) *(readme)* Update intro by @timonv - [1dc4c90](https://github.com/bosun-ai/swiftide/commit/1dc4c90436c9c8c8d0eb080e300afce53090c73e) *(readme)* Add new blog links by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.9.1...v0.9.2 ## [v0.9.1](https://github.com/bosun-ai/swiftide/releases/tag/v0.9.1) - 2024-09-01 ### New features - [b891f93](https://github.com/bosun-ai/swiftide/commit/b891f932e43b9c76198d238bcde73a6bb1dfbfdb) *(integrations)* Add fluvio as loader support (#243) by @timonv ````text Adds Fluvio as a loader support, enabling Swiftide indexing streams to process messages from a Fluvio topic. ```` - [c00b6c8](https://github.com/bosun-ai/swiftide/commit/c00b6c8f08fca46451387f3034d3d53805f3e401) *(query)* Ragas support (#236) by @timonv ````text Work in progress on support for ragas as per https://github.com/explodinggradients/ragas/issues/1165 and #232 Add an optional evaluator to a pipeline. Evaluators need to handle transformation events in the query pipeline. The Ragas evaluator captures the transformations as per https://docs.ragas.io/en/latest/howtos/applications/data_preparation.html. You can find a working notebook here https://github.com/bosun-ai/swiftide-tutorial/blob/c510788a625215f46575415161659edf26fc1fd5/ragas/notebook.ipynb with a pipeline using it here https://github.com/bosun-ai/swiftide-tutorial/pull/1 ```` - [a1250c1](https://github.com/bosun-ai/swiftide/commit/a1250c1cef57e2b74760fd31772e106993a3b079) LanceDB support (#254) by @timonv ````text Add LanceDB support for indexing and querying. LanceDB separates compute from storage, where storage can be local or hosted elsewhere. 
```` ### Bug fixes - [f92376d](https://github.com/bosun-ai/swiftide/commit/f92376d551a3bf4fe39d81a64c4328a742677669) *(deps)* Update rust crate aws-sdk-bedrockruntime to v1.46.0 (#247) by @renovate[bot] - [732a166](https://github.com/bosun-ai/swiftide/commit/732a166f388d4aefaeec694103e3d1ff57655d69) Remove no default features from futures-util by @timonv ### Miscellaneous - [9b257da](https://github.com/bosun-ai/swiftide/commit/9b257dadea6c07f720ac4ea447342b2f6d91d0ec) Default features cleanup (#262) by @timonv ````text Integrations are messy and pull a lot in. A potential solution is to disable default features, only add what is actually required, and put the responsibility at users if they need anything specific. Feature unification should then take care of the rest. ```` ### Docs - [fb381b8](https://github.com/bosun-ai/swiftide/commit/fb381b8896a5fc863a4185445ce51fefb99e6c11) *(readme)* Copy improvements (#261) by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.9.0...v0.9.1 ## [v0.9.0](https://github.com/bosun-ai/swiftide/releases/tag/v0.9.0) - 2024-08-15 ### New features - [2443933](https://github.com/bosun-ai/swiftide/commit/24439339a9b935befcbcc92e56c01c5048605138) *(qdrant)* Add access to inner client for custom operations (#242) by @timonv - [4fff613](https://github.com/bosun-ai/swiftide/commit/4fff613b461e8df993327cb364cabc65cd5901d8) *(query)* Add concurrency on query pipeline and add query_all by @timonv ### Bug fixes - [4e31c0a](https://github.com/bosun-ai/swiftide/commit/4e31c0a6cdc6b33e4055f611dc48d3aebf7514ae) *(deps)* Update rust crate aws-sdk-bedrockruntime to v1.44.0 (#244) by @renovate[bot] - [501321f](https://github.com/bosun-ai/swiftide/commit/501321f811a0eec8d1b367f7c7f33b1dfd29d2b6) *(deps)* Update rust crate spider to v1.99.37 (#230) by @renovate[bot] - [8a1cc69](https://github.com/bosun-ai/swiftide/commit/8a1cc69712b4361893c0564c7d6f7d1ed21e5710) *(query)* After retrieval current transformation should be empty by
@timonv ### Miscellaneous - [e9d0016](https://github.com/bosun-ai/swiftide/commit/e9d00160148807a8e2d1df1582e6ea85cfd2d8d0) *(indexing,integrations)* Move tree-sitter dependencies to integrations (#235) by @timonv ````text Removes the dependency of indexing on integrations, resulting in much faster builds when developing on indexing. ```` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.8.0...v0.9.0 ## [v0.8.0](https://github.com/bosun-ai/swiftide/releases/tag/v0.8.0) - 2024-08-12 ### New features - [2e25ad4](https://github.com/bosun-ai/swiftide/commit/2e25ad4b999a8562a472e086a91020ec4f8300d8) *(indexing)* [**breaking**] Default LLM for indexing pipeline and boilerplate Transformer macro (#227) by @timonv ````text Add setting a default LLM for an indexing pipeline, avoiding the need to clone multiple times. More importantly, introduced `swiftide-macros` with `#[swiftide_macros::indexing_transformer]` that generates all boilerplate code used for internal transformers. This ensures all transformers are consistent and makes them easy to change in the future. This is a big win for maintainability and ease to extend. Users are encouraged to use the macro as well. ```` **BREAKING CHANGE**: Introduces `WithIndexingDefaults` and `WithBatchIndexingDefaults` trait constraints for transformers. They can be used as a marker with a noop (i.e. just `impl WithIndexingDefaults for MyTransformer {}`). However, when implemented fully, they can be used to provide defaults from the pipeline to your transformers. - [67336f1](https://github.com/bosun-ai/swiftide/commit/67336f1d9c7fde474bdddfd0054b40656df244e0) *(indexing)* Sparse vector support with Splade and Qdrant (#222) by @timonv ````text Adds Sparse vector support to the indexing pipeline, enabling hybrid search for vector databases. The design should work for any form of Sparse embedding, and works with existing embedding modes and multiple named vectors. 
Additionally, added `try_default_sparse` to FastEmbed, using Splade, so it's fully usable. Hybrid search in the query pipeline coming soon. ```` - [e728a7c](https://github.com/bosun-ai/swiftide/commit/e728a7c7a2fcf7b22c31e5d6c66a896f634f6901) Code outlines in chunk metadata (#137) by @tinco ````text Added a transformer that generates outlines for code files using tree sitter. And another that compresses the outline to be more relevant to chunks. Additionally added a step to the metadata QA tool that uses the outline to improve the contextual awareness during QA generation. ```` ### Bug fixes - [dc7412b](https://github.com/bosun-ai/swiftide/commit/dc7412beda4377e8a6222b3ad576f0a1af332533) *(deps)* Update aws-sdk-rust monorepo (#223) by @renovate[bot] ### Miscellaneous - [9613f50](https://github.com/bosun-ai/swiftide/commit/9613f50c0036b42411cd3a3014f54b592fe4958a) *(ci)* Only show remote github url if present in changelog by @timonv ### Docs - [73d1649](https://github.com/bosun-ai/swiftide/commit/73d1649ca8427aa69170f6451eac55316581ed9a) *(readme)* Add Ollama support to README by @timonv - [b3f04de](https://github.com/bosun-ai/swiftide/commit/b3f04defe94e5b26876c8d99049f4d87b5f2dc18) *(readme)* Add link to discord (#219) by @timonv - [4970a68](https://github.com/bosun-ai/swiftide/commit/4970a683acccc71503e64044dc02addaf2e9c87c) *(readme)* Fix discord links by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/v0.7.1...v0.8.0 ## [v0.7.1](https://github.com/bosun-ai/swiftide/releases/tag/v0.7.1) - 2024-08-04 ### New features - [b2d31e5](https://github.com/bosun-ai/swiftide/commit/b2d31e555cb8da525513490e7603df1f6b2bfa5b) *(integrations)* Add ollama support (#214) by @tinco - [9eb5894](https://github.com/bosun-ai/swiftide/commit/9eb589416c2a56f9942b6f6bed3771cec6acebaf) *(query)* Add support for closures in all steps (#215) by @timonv ### Miscellaneous - 
[53e662b](https://github.com/bosun-ai/swiftide/commit/53e662b8c30f6ac6d11863685d3850ab48397766) *(ci)* Add cargo deny to lint dependencies (#213) by @timonv ### Docs - [1539393](https://github.com/bosun-ai/swiftide/commit/15393932dd756af134a12f7954faa75893f8c3fb) *(readme)* Update README.md by @timonv - [ba07ab9](https://github.com/bosun-ai/swiftide/commit/ba07ab93722d974ac93ed5d4a22bf53317bc11ae) *(readme)* Readme improvements by @timonv - [f7accde](https://github.com/bosun-ai/swiftide/commit/f7accdeecf01efc291503282554257846725ce57) *(readme)* Add 0.7 announcement by @timonv - [084548f](https://github.com/bosun-ai/swiftide/commit/084548f0fbfbb8cf6d359585f30c8e2593565681) *(readme)* Clarify on closures by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.7.0...v0.7.1 ## [swiftide-v0.7.0](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.7.0) - 2024-07-28 ### New features - [ec1fb04](https://github.com/bosun-ai/swiftide/commit/ec1fb04573ab75fe140cbeff17bc3179e316ff0c) *(indexing)* Metadata as first class citizen (#204) by @timonv ````text Adds our own implementation for metadata, internally still using a BTreeMap. The Value type is now a `serde_json::Value` enum. This allows us to store the metadata in the same format as the rest of the document, and also allows us to use values programmatically later. As is, all current meta data is still stored as Strings. ```` - [16bafe4](https://github.com/bosun-ai/swiftide/commit/16bafe4da8c98adcf90f5bb63070832201c405b9) *(swiftide)* [**breaking**] Rework workspace preparing for swiftide-query (#199) by @timonv ````text Splits up the project into multiple small, unpublished crates. Boosts compile times, makes the code a bit easier to grok and enables swiftide-query to be built separately. 
```` **BREAKING CHANGE**: All indexing related tools are now in - [63694d2](https://github.com/bosun-ai/swiftide/commit/63694d2892a7c97a7e7fc42664d550c5acd7bb12) *(swiftide-query)* Query pipeline v1 (#189) by @timonv ### Bug fixes - [ee3aad3](https://github.com/bosun-ai/swiftide/commit/ee3aad37a40eb9f18c9a3082ad6826ff4b6c7245) *(deps)* Update rust crate aws-sdk-bedrockruntime to v1.42.0 (#195) by @renovate[bot] - [be0f31d](https://github.com/bosun-ai/swiftide/commit/be0f31de4f0c7842e23628fd6144cc4406c165c0) *(deps)* Update rust crate spider to v1.99.11 (#190) by @renovate[bot] - [dd04453](https://github.com/bosun-ai/swiftide/commit/dd04453ecb8d04326929780e9e52155b37d731e2) *(swiftide)* Update main lockfile by @timonv - [bafd907](https://github.com/bosun-ai/swiftide/commit/bafd90706346c3e208390f1296f10e2c17ad61b1) Update all cargo package descriptions by @timonv ### Miscellaneous - [e72641b](https://github.com/bosun-ai/swiftide/commit/e72641b677cfd1b21e98fd74552728dbe3e7a9bc) *(ci)* Set versions in dependencies by @timonv ### Docs - [2114aa4](https://github.com/bosun-ai/swiftide/commit/2114aa4394f4eda2e6465e1adb5602ae1b3ff61f) *(readme)* Add copy on the query pipeline by @timonv - [573aff6](https://github.com/bosun-ai/swiftide/commit/573aff6fee3f891bae61e92e131dd15425cefc29) *(indexing)* Document the default prompt templates and their context (#206) by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.6.7...swiftide-v0.7.0 ## [swiftide-v0.6.7](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.6.7) - 2024-07-23 ### New features - [beea449](https://github.com/bosun-ai/swiftide/commit/beea449301b89fde1915c5336a071760c1963c75) *(prompt)* Add Into for strings to PromptTemplate (#193) by @timonv - [f3091f7](https://github.com/bosun-ai/swiftide/commit/f3091f72c74e816f6b9b8aefab058d610becb625) *(transformers)* References and definitions from code (#186) by @timonv ### Docs - 
[97a572e](https://github.com/bosun-ai/swiftide/commit/97a572ec2e3728bbac82c889bf5129b048e61e0c) *(readme)* Add blog posts and update doc link (#194) by @timonv - [504fe26](https://github.com/bosun-ai/swiftide/commit/504fe2632cf4add506dfb189c17d6e4ecf6f3824) *(pipeline)* Add note that closures can also be used as transformers by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.6.6...swiftide-v0.6.7 ## [swiftide-v0.6.6](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.6.6) - 2024-07-16 ### New features - [d1c642a](https://github.com/bosun-ai/swiftide/commit/d1c642aa4ee9b373e395a78591dd36fa0379a4ff) *(groq)* Add SimplePrompt support for Groq (#183) by @timonv ````text Adds simple prompt support for Groq by using async_openai. ~~Needs some double checks~~. Works great. ```` ### Bug fixes - [5d4a814](https://github.com/bosun-ai/swiftide/commit/5d4a8145b6952b2f4f9a1f144913673eeb3aaf24) *(deps)* Update rust crate aws-sdk-bedrockruntime to v1.40.0 (#169) by @renovate[bot] ### Docs - [143c7c9](https://github.com/bosun-ai/swiftide/commit/143c7c9c2638737166f23f2ef8106b7675f6e19b) *(readme)* Fix typo (#180) by @eltociear - [d393181](https://github.com/bosun-ai/swiftide/commit/d3931818146bff72499ebfcc0d0e8c8bb13a760d) *(docsrs)* Scrape examples and fix links (#184) by @timonv ### New Contributors * @eltociear made their first contribution in [#180](https://github.com/bosun-ai/swiftide/pull/180) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.6.5...swiftide-v0.6.6 ## [swiftide-v0.6.5](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.6.5) - 2024-07-15 ### New features - [0065c7a](https://github.com/bosun-ai/swiftide/commit/0065c7a7fd1289ea227391dd7b9bd51c905290d5) *(prompt)* Add extending the prompt repository (#178) by @timonv ### Bug fixes - [b54691f](https://github.com/bosun-ai/swiftide/commit/b54691f769e2d0ac7886938b6e837551926eea2f) *(prompts)* Include default prompts in crate (#174) 
by @timonv ````text - **add prompts to crate** - **load prompts via cargo manifest dir** ```` - [3c297bb](https://github.com/bosun-ai/swiftide/commit/3c297bbb85fd3ae9b411a691024f622702da3617) *(swiftide)* Remove include from Cargo.toml by @timonv ### Miscellaneous - [73d5fa3](https://github.com/bosun-ai/swiftide/commit/73d5fa37d23f53919769c2ffe45db2e3832270ef) *(traits)* Cleanup unused batch size in `BatchableTransformer` (#177) by @timonv ### Docs - [b95b395](https://github.com/bosun-ai/swiftide/commit/b95b3955f89ed231cc156dab749ee7bb8be98ee5) *(swiftide)* Documentation improvements and cleanup (#176) by @timonv ````text - **chore: remove ingestion stream** - **Documentation and grammar** ```` **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.6.3...swiftide-v0.6.5 ## [swiftide-v0.6.3](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.6.3) - 2024-07-14 ### Bug fixes - [47418b5](https://github.com/bosun-ai/swiftide/commit/47418b5d729aef1e2ff77dabd7e29b5131512b01) *(prompts)* Fix breaking issue with prompts not found by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.6.2...swiftide-v0.6.3 ## [swiftide-v0.6.2](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.6.2) - 2024-07-12 ### Miscellaneous - [2b682b2](https://github.com/bosun-ai/swiftide/commit/2b682b28fd146fac2c61f1ee430534a04b9fa7ce) *(deps)* Limit feature flags on qdrant to fix docsrs by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.6.1...swiftide-v0.6.2 ## [swiftide-v0.6.1](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.6.1) - 2024-07-12 ### Miscellaneous - [aae7ab1](https://github.com/bosun-ai/swiftide/commit/aae7ab18f8c9509fd19f83695e4eca942c377043) *(deps)* Patch update all by @timonv ### Docs - [085709f](https://github.com/bosun-ai/swiftide/commit/085709fd767bab7153b2222907fc500ad4412570) *(docsrs)* Disable unstable and rustdoc scraping by @timonv **Full 
Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.6.0...swiftide-v0.6.1 ## [swiftide-v0.6.0](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.6.0) - 2024-07-12 ### New features - [70ea268](https://github.com/bosun-ai/swiftide/commit/70ea268b19e564af83bb834f56d406a05e02e9cd) *(prompts)* Add prompts as first class citizens (#145) by @timonv ````text Adds Prompts as first class citizens. This is a breaking change as SimplePrompt with just a `&str` is no longer allowed. This introduces `Prompt` and `PromptTemplate`. A template uses jinja style templating built on tera. Templates can be converted into prompts, and have context added. A prompt is then sent to something that prompts, i.e. openai or bedrock. Additional prompts can be added either compiled or as one-offs. Additionally, it's perfectly fine to prompt with just a string as well, just provide an `.into()`. For future development, some LLMs really benefit from system prompts, which this would enable. For the query pipeline we can also take a much more structured approach with composed templates and conditionals. ```` - [699cfe4](https://github.com/bosun-ai/swiftide/commit/699cfe44fb0e3baddba695ad09836caec7cb30a6) Embed modes and named vectors (#123) by @pwalski ````text Added named vector support to qdrant. A pipeline can now have its embed mode configured, either per field, chunk and metadata combined (default) or both. Vectors need to be configured on the qdrant client side. See `examples/store_multiple_vectors.rs` for an example. Shoutout to @pwalski for the contribution. Closes #62. 
--------- ```` ### Bug fixes - [9334934](https://github.com/bosun-ai/swiftide/commit/9334934e4af92b35dbc61e1f92aa90abac29ca12) *(chunkcode)* Use correct chunksizes (#122) by @timonv - [dfc76dd](https://github.com/bosun-ai/swiftide/commit/dfc76ddfc23d9314fe88c8362bf53d7865a03302) *(deps)* Update rust crate serde to v1.0.204 (#129) by @renovate[bot] - [28f5b04](https://github.com/bosun-ai/swiftide/commit/28f5b048f5acd977915ae20463f8fbb473dfab9a) *(deps)* Update rust crate tree-sitter-typescript to v0.21.2 (#128) by @renovate[bot] - [9c261b8](https://github.com/bosun-ai/swiftide/commit/9c261b87dde2e0caaff0e496d15681466844daf4) *(deps)* Update rust crate text-splitter to v0.14.1 (#127) by @renovate[bot] - [ff92abd](https://github.com/bosun-ai/swiftide/commit/ff92abd95908365c72d96abff37e0284df8fed32) *(deps)* Update rust crate tree-sitter-javascript to v0.21.4 (#126) by @renovate[bot] - [7af97b5](https://github.com/bosun-ai/swiftide/commit/7af97b589ca45f2b966ea2f61ebef341c881f1f9) *(deps)* Update rust crate spider to v1.98.7 (#124) by @renovate[bot] - [adc4bf7](https://github.com/bosun-ai/swiftide/commit/adc4bf789f679079fcc9fac38f4a7b8f98816844) *(deps)* Update aws-sdk-rust monorepo (#125) by @renovate[bot] - [dd32ef3](https://github.com/bosun-ai/swiftide/commit/dd32ef3b1be7cd6888d2961053d0b3c1a882e1a4) *(deps)* Update rust crate async-trait to v0.1.81 (#134) by @renovate[bot] - [2b13523](https://github.com/bosun-ai/swiftide/commit/2b1352322e574b62cb30268b35c6b510122f0584) *(deps)* Update rust crate fastembed to v3.7.1 (#135) by @renovate[bot] - [8e22937](https://github.com/bosun-ai/swiftide/commit/8e22937427b928524dacf2b446feeff726b6a5e1) *(deps)* Update rust crate aws-sdk-bedrockruntime to v1.39.0 (#143) by @renovate[bot] - [353cd9e](https://github.com/bosun-ai/swiftide/commit/353cd9ed36fcf6fb8f1db255d8b5f4a914ca8496) *(qdrant)* Upgrade and better defaults (#118) by @timonv ````text - **fix(deps): update rust crate qdrant-client to v1.10.1** - **fix(qdrant): upgrade 
to new qdrant with sensible defaults** - **feat(qdrant): safe to clone with internal arc** --------- ```` - [b53636c](https://github.com/bosun-ai/swiftide/commit/b53636cbd8f179f248cc6672aaf658863982c603) Inability to store only some of `EmbeddedField`s (#139) by @pwalski ### Performance - [ea8f823](https://github.com/bosun-ai/swiftide/commit/ea8f8236cdd9c588e55ef78f9eac27db1f13b2d9) Improve local build performance and crate cleanup (#148) by @timonv ````text - **tune cargo for faster builds** - **perf(swiftide): increase local build performance** ```` ### Miscellaneous - [eb8364e](https://github.com/bosun-ai/swiftide/commit/eb8364e08a9202476cca6b60fbdfbb31fe0e1c3d) *(ci)* Try overriding the github repo for git cliff by @timonv - [5de6af4](https://github.com/bosun-ai/swiftide/commit/5de6af42b9a1e95b0fbd54659c0d590db1d76222) *(ci)* Only add contributors if present by @timonv - [4c9ed77](https://github.com/bosun-ai/swiftide/commit/4c9ed77c85b7dd0e8722388b930d169cd2e5a5c7) *(ci)* Properly check if contributors are present by @timonv - [c5bf796](https://github.com/bosun-ai/swiftide/commit/c5bf7960ca6bec498cdc987fe7676acfef702e5b) *(ci)* Add clippy back to ci (#147) by @timonv - [7a8843a](https://github.com/bosun-ai/swiftide/commit/7a8843ab9e64b623870ebe49079ec976aae56d5c) *(deps)* Update rust crate testcontainers to 0.20.0 (#133) by @renovate[bot] - [364e13d](https://github.com/bosun-ai/swiftide/commit/364e13d83285317a1fb99889f6d74ad32b58c482) *(swiftide)* Loosen up dependencies (#140) by @timonv ````text Loosen up dependencies so swiftide is a bit more flexible to add to existing projects ```` - [84dd65d](https://github.com/bosun-ai/swiftide/commit/84dd65dc6c0ff4595f27ed061a4f4c0a2dae7202) [**breaking**] Rename all mentions of ingest to index (#130) by @timonv ````text Swiftide is not an ingestion pipeline (loading data), but an indexing pipeline (prepping for search). There is now a temporary, deprecated re-export to match the previous api. 
```` **BREAKING CHANGE**: rename all mentions of ingest to index (#130) - [51c114c](https://github.com/bosun-ai/swiftide/commit/51c114ceb06db840c4952d3d0f694bfbf266681c) Various tooling & community improvements (#131) by @timonv ````text - **fix(ci): ensure clippy runs with all features** - **chore(ci): coverage using llvm-cov** - **chore: drastically improve changelog generation** - **chore(ci): add sanity checks for pull requests** - **chore(ci): split jobs and add typos** ```` - [d2a9ea1](https://github.com/bosun-ai/swiftide/commit/d2a9ea1e7afa6f192bf9c32bbb54d9bb6e46472e) Enable clippy pedantic (#132) by @timonv ### Docs - [8405c9e](https://github.com/bosun-ai/swiftide/commit/8405c9efedef944156c2904eb709ba79aa4d82de) *(contributing)* Add guidelines on code design (#113) by @timonv - [3e447fe](https://github.com/bosun-ai/swiftide/commit/3e447feab83a4bf8d7d9d8220fe1b92dede9af79) *(readme)* Link to CONTRIBUTING (#114) by @timonv - [4c40e27](https://github.com/bosun-ai/swiftide/commit/4c40e27e5c6735305c70696ddf71dd5f95d03bbb) *(readme)* Add back coverage badge by @timonv - [5691ac9](https://github.com/bosun-ai/swiftide/commit/5691ac930fd6547c3f0166b64ead0ae647c38883) *(readme)* Add preproduction warning by @timonv - [37af322](https://github.com/bosun-ai/swiftide/commit/37af3225b4c3464aa4ed67f8f456c26f3d445507) *(rustdocs)* Rewrite the initial landing page (#149) by @timonv ````text - **Add homepage and badges to cargo toml** - **documentation landing page improvements** ```` - [7686c2d](https://github.com/bosun-ai/swiftide/commit/7686c2d449b5df0fddc08b111174357d47459f86) Templated prompts are now a major feature by @timonv ### New Contributors * @pwalski made their first contribution in [#139](https://github.com/bosun-ai/swiftide/pull/139) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.5.0...swiftide-v0.6.0 ## [swiftide-v0.5.0](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.5.0) - 2024-07-01 ### New features - 
[6a88651](https://github.com/bosun-ai/swiftide/commit/6a88651df8c6b91add03acfc071fb9479545b8af) *(ingestion_pipeline)* Implement filter (#109) by @timonv - [5aeb3a7](https://github.com/bosun-ai/swiftide/commit/5aeb3a7fb75b21b2f24b111e9640ea4985b2e316) *(ingestion_pipeline)* Splitting and merging streams by @timonv - [8812fbf](https://github.com/bosun-ai/swiftide/commit/8812fbf30b882b68bf25f3d56b3ddf17af0bcb7a) *(ingestion_pipeline)* Build a pipeline from a stream by @timonv - [6101bed](https://github.com/bosun-ai/swiftide/commit/6101bed812c5167eb87a4093d66005140517598d) AWS bedrock support (#92) by @timonv ````text Adds an integration with AWS Bedrock, implementing SimplePrompt for Anthropic and Titan models. More can be added if there is a need. Same for the embedding models. ```` ### Bug fixes - [17a2be1](https://github.com/bosun-ai/swiftide/commit/17a2be1de6c0f3bda137501db4b1703f9ed0b1c5) *(changelog)* Add scope by @timonv - [a12cce2](https://github.com/bosun-ai/swiftide/commit/a12cce230032eebe2f7ff1aa9cdc85b8fc200eb1) *(openai)* Add tests for builder by @timonv - [963919b](https://github.com/bosun-ai/swiftide/commit/963919b0947faeb7d96931c19e524453ad4a0007) *(transformers)* [**breaking**] Fix too small chunks being retained and api by @timonv **BREAKING CHANGE**: Fix too small chunks being retained and api - [5e8da00](https://github.com/bosun-ai/swiftide/commit/5e8da008ce08a23377672a046a4cedd48d4cf30c) Fix oversight in ingestion pipeline tests by @timonv - [e8198d8](https://github.com/bosun-ai/swiftide/commit/e8198d81354bbca2c21ca08b9522d02b8c93173b) Use git cliff manually for changelog generation by @timonv - [2c31513](https://github.com/bosun-ai/swiftide/commit/2c31513a0ded87addd0519bbfdd63b5abed29f73) Just use keepachangelog by @timonv - [6430af7](https://github.com/bosun-ai/swiftide/commit/6430af7b57eecb7fdb954cd89ade4547b8e92dbd) Use native cargo bench format and only run benchmarks crate by @timonv - 
[cba981a](https://github.com/bosun-ai/swiftide/commit/cba981a317a80173eff2946fc551d1a36ec40f65) Replace unwrap with expect and add comment on panic by @timonv ### Miscellaneous - [e243212](https://github.com/bosun-ai/swiftide/commit/e2432123f0dfc48147ebed13fe6e3efec3ff7b3f) *(ci)* Enable continuous benchmarking and improve benchmarks (#98) by @timonv - [2dbf14c](https://github.com/bosun-ai/swiftide/commit/2dbf14c34bed2ee40ab79c0a46d011cd20882bda) *(ci)* Fix benchmarks in ci by @timonv - [b155de6](https://github.com/bosun-ai/swiftide/commit/b155de6387ddfe64d1a177b31c8e1ed93739b2c9) *(ci)* Fix naming of github actions by @timonv - [206e432](https://github.com/bosun-ai/swiftide/commit/206e432dd291dd6a4592a6fb5f890049595311cb) *(ci)* Add support for merge queues by @timonv - [46752db](https://github.com/bosun-ai/swiftide/commit/46752dbfc8ccd578ddba915fd6cd6509e3e6fb14) *(ci)* Add concurrency configuration by @timonv - [5f09c11](https://github.com/bosun-ai/swiftide/commit/5f09c116f418cecb96fb1e86161333908d1a4d70) Add initial benchmarks by @timonv - [162c6ef](https://github.com/bosun-ai/swiftide/commit/162c6ef2a07e40b8607b0ab6773909521f0bb798) Ensure feat is always in Added by @timonv ### Docs - [929410c](https://github.com/bosun-ai/swiftide/commit/929410cb1c2d81b6ffaec4c948c891472835429d) *(readme)* Add diagram to the readme (#107) by @timonv - [b014f43](https://github.com/bosun-ai/swiftide/commit/b014f43aa187881160245b4356f95afe2c6fe98c) Improve documentation across the project (#112) by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.4.3...swiftide-v0.5.0 ## [swiftide-v0.4.3](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.4.3) - 2024-06-28 ### Bug fixes - [ab3dc86](https://github.com/bosun-ai/swiftide/commit/ab3dc861490a0d1ab94f96e741e09c860094ebc0) *(memory_storage)* Fallback to incremental counter when missing id by @timonv ### Miscellaneous - 
[bdebc24](https://github.com/bosun-ai/swiftide/commit/bdebc241507e9f55998e96ca4aece530363716af) Clippy by @timonv ### Docs - [dad3e02](https://github.com/bosun-ai/swiftide/commit/dad3e02fdc8a57e9de16832090c44c536e7e394b) *(readme)* Add ci badge by @timonv - [4076092](https://github.com/bosun-ai/swiftide/commit/40760929d24e20631d0552d87bdbb4fdf9195453) *(readme)* Clean up and consistent badge styles by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.4.2...swiftide-v0.4.3 ## [swiftide-v0.4.2](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.4.2) - 2024-06-26 ### New features - [926cc0c](https://github.com/bosun-ai/swiftide/commit/926cc0cca46023bcc3097a97b10ce03ae1fc3cc2) *(ingestion_stream)* Implement into for Result> by @timonv ### Bug fixes - [3143308](https://github.com/bosun-ai/swiftide/commit/3143308136ec4e71c8a5f9a127119e475329c1a2) *(embed)* Panic if number of embeddings and node are equal by @timonv ### Miscellaneous - [5ed08bb](https://github.com/bosun-ai/swiftide/commit/5ed08bb259b7544d3e4f2acdeef56231aa32e17c) Cleanup changelog by @timonv ### Docs - [47aa378](https://github.com/bosun-ai/swiftide/commit/47aa378c4a70c47a2b313b6eca8dcf02b4723963) Create CONTRIBUTING.md by @timonv - [0660d5b](https://github.com/bosun-ai/swiftide/commit/0660d5b08aed15d62f077363eae80f621ddaa510) Readme updates by @timonv ### Refactor - [d285874](https://github.com/bosun-ai/swiftide/commit/d28587448d7fe342a79ac687cd5d7ee27354cae6) *(ingestion_pipeline)* Log_all combines other log helpers by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.4.1...swiftide-v0.4.2 ## [swiftide-v0.4.1](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.4.1) - 2024-06-24 ### New features - [3898ee7](https://github.com/bosun-ai/swiftide/commit/3898ee7d6273ee7034848f9ab08fd85613cb5b32) *(memory_storage)* Can be cloned safely preserving storage by @timonv - 
[92052bf](https://github.com/bosun-ai/swiftide/commit/92052bfdbca8951620f6d016768d252e793ecb5d) *(transformers)* Allow for arbitrary closures as transformers and batchable transformers by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.4.0...swiftide-v0.4.1 ## [swiftide-v0.4.0](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.4.0) - 2024-06-23 ### New features - [477a284](https://github.com/bosun-ai/swiftide/commit/477a284597359472988ecde372e080f60aab0804) *(benchmarks)* Add benchmark for the file loader by @timonv - [1567940](https://github.com/bosun-ai/swiftide/commit/15679409032e9be347fbe8838a308ff0d09768b8) *(benchmarks)* Add benchmark for simple local pipeline by @timonv - [2228d84](https://github.com/bosun-ai/swiftide/commit/2228d84ccaad491e2c3cd0feb948050ad2872cf0) *(examples)* Example for markdown with all metadata by @timonv - [9a1e12d](https://github.com/bosun-ai/swiftide/commit/9a1e12d34e02fe2292ce679251b96d61be74c884) *(examples,scraping)* Add example scraping and ingesting a url by @timonv - [15deeb7](https://github.com/bosun-ai/swiftide/commit/15deeb72ca2e131e8554fa9cbefa3ef369de752a) *(ingestion_node)* Add constructor with defaults by @timonv - [4d5c68e](https://github.com/bosun-ai/swiftide/commit/4d5c68e7bb09fae18832e2a453f114df5ba32ce1) *(ingestion_node)* Improved human readable Debug by @timonv - [a5051b7](https://github.com/bosun-ai/swiftide/commit/a5051b79b2ce62d41dd93f7b34a1a065d9878732) *(ingestion_pipeline)* Optional error filtering and logging (#75) by @timonv - [062107b](https://github.com/bosun-ai/swiftide/commit/062107b46474766640c38266f6fd6c27a95d4b57) *(ingestion_pipeline)* Implement throttling a pipeline (#77) by @timonv - [a2ffc78](https://github.com/bosun-ai/swiftide/commit/a2ffc78f6d25769b9b7894f1f0703d51242023d4) *(ingestion_stream)* Improved stream developer experience (#81) by @timonv ````text Improves stream ergonomics by providing convenient helpers and `Into` for streams, 
vectors and iterators that match the internal type. This means that in many cases, trait implementers can simply call `.into()` instead of manually constructing a stream. In the case it's an iterator, they can now use `IngestionStream::iter()` instead. ```` - [d260674](https://github.com/bosun-ai/swiftide/commit/d2606745de8b22dcdf02e244d1b044efe12c6ac7) *(integrations)* [**breaking**] Support fastembed (#60) by @timonv ````text Adds support for FastEmbed with various models. Includes a breaking change, renaming the Embed trait to EmbeddingModel. ```` **BREAKING CHANGE**: support fastembed (#60) - [9004323](https://github.com/bosun-ai/swiftide/commit/9004323dc5b11a3556a47e11fb8912ffc49f1e9e) *(integrations)* [**breaking**] Implement Persist for Redis (#80) by @timonv **BREAKING CHANGE**: implement Persist for Redis (#80) - [eb84dd2](https://github.com/bosun-ai/swiftide/commit/eb84dd27c61a1b3a4a52a53cc0404203eac729e8) *(integrations,transformers)* Add transformer for converting html to markdown by @timonv - [ef7dcea](https://github.com/bosun-ai/swiftide/commit/ef7dcea45bfc336e7defcaac36bb5a6ff27d5acd) *(loaders)* File loader performance improvements by @timonv - [6d37051](https://github.com/bosun-ai/swiftide/commit/6d37051a9c2ef24ea7eb3815efcf9692df0d70ce) *(loaders)* Add scraping using `spider` by @timonv - [2351867](https://github.com/bosun-ai/swiftide/commit/235186707182e8c39b8f22c6dd9d54eb32f7d1e5) *(persist)* In memory storage for testing, experimentation and debugging by @timonv - [4d5d650](https://github.com/bosun-ai/swiftide/commit/4d5d650f235395aa81816637d559de39853e1db1) *(traits)* Add automock for simpleprompt by @timonv - [bd6f887](https://github.com/bosun-ai/swiftide/commit/bd6f8876d010d23f651fd26a48d6775c17c98e94) *(transformers)* Add transformers for title, summary and keywords by @timonv ### Bug fixes - [7cbfc4e](https://github.com/bosun-ai/swiftide/commit/7cbfc4e13745ee5a6776a97fc6db06608fae8e81) *(ingestion_pipeline)* Concurrency does not work when 
spawned (#76) by @timonv ````text Concurrency did not work as expected. When spawning via `Tokio::spawn` the future would be polled directly, and any concurrency setting would not be respected. Because it had to be removed, improved tracing for each step as well. ```` ### Miscellaneous - [f4341ba](https://github.com/bosun-ai/swiftide/commit/f4341babe5807b268ce86a88e0df4bfc6d756de4) *(ci)* Single changelog for all (future) crates in root (#57) by @timonv - [7dde8a0](https://github.com/bosun-ai/swiftide/commit/7dde8a0811c7504b807b3ef9f508ce4be24967b8) *(ci)* Code coverage reporting (#58) by @timonv ````text Post test coverage to Coveralls Also enabled --all-features when running tests in ci, just to be sure ```` - [cb7a2cd](https://github.com/bosun-ai/swiftide/commit/cb7a2cd3a72f306a0b46556caee0a25c7ba2c0e0) *(scraping)* Exclude spider from test coverage by @timonv - [7767588](https://github.com/bosun-ai/swiftide/commit/77675884a2eeb0aab6ce57dccd2a260f5a973197) *(transformers)* Improve test coverage by @timonv - [3b7c0db](https://github.com/bosun-ai/swiftide/commit/3b7c0dbc2f020ce84a5da5691ee6eb415df2d466) Move changelog to root by @timonv - [d6d0215](https://github.com/bosun-ai/swiftide/commit/d6d021560a05508add07a72f4f438d3ea3f1cb2c) Properly quote crate name in changelog by @timonv - [f251895](https://github.com/bosun-ai/swiftide/commit/f2518950427ef758fd57e6e6189ce600adf19940) Documentation and feature flag cleanup (#69) by @timonv ````text With fastembed added our dependencies become rather heavy. By default now disable all integrations and either provide 'all' or cherry pick integrations. 
```` - [f6656be](https://github.com/bosun-ai/swiftide/commit/f6656becd199762843a59b0f86871753360a08f0) Cargo update by @timonv ### Docs - [53ed920](https://github.com/bosun-ai/swiftide/commit/53ed9206835da1172295e296119ee9a883605f18) Hide the table of contents by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.3.3...swiftide-v0.4.0 ## [swiftide-v0.3.3](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.3.3) - 2024-06-16 ### New features - [bdaed53](https://github.com/bosun-ai/swiftide/commit/bdaed5334b3e122f803370cc688dd2f662db0b8d) *(integrations)* Clone and debug for integrations by @timonv - [318e538](https://github.com/bosun-ai/swiftide/commit/318e538acb30ca516a780b5cc42c8ab2ed91cd6b) *(transformers)* Builder and clone for chunk_code by @timonv - [c074cc0](https://github.com/bosun-ai/swiftide/commit/c074cc0edb8b0314de15f9a096699e3e744c9f33) *(transformers)* Builder for chunk_markdown by @timonv - [e18e7fa](https://github.com/bosun-ai/swiftide/commit/e18e7fafae3007f1980bb617b7a72dd605720d74) *(transformers)* Builder and clone for MetadataQACode by @timonv - [fd63dff](https://github.com/bosun-ai/swiftide/commit/fd63dffb4f0b11bb9fa4fadc7b076463eca111a6) *(transformers)* Builder and clone for MetadataQAText by @timonv ### Miscellaneous - [678106c](https://github.com/bosun-ai/swiftide/commit/678106c01b7791311a24425c22ea39366b664033) *(ci)* Pretty names for pipelines (#54) by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.3.2...swiftide-v0.3.3 ## [swiftide-v0.3.2](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.3.2) - 2024-06-16 ### New features - [b211002](https://github.com/bosun-ai/swiftide/commit/b211002e40ef16ef240e142c0178b04636a4f9aa) *(integrations)* Qdrant and openai builder should be consistent (#52) by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.3.1...swiftide-v0.3.2 ## 
[swiftide-v0.3.1](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.3.1) - 2024-06-15 ### Docs - [6f63866](https://github.com/bosun-ai/swiftide/commit/6f6386693f3f6e0328eedaa4fb69cd8d0694574b) We love feedback <3 by @timonv - [7d79b64](https://github.com/bosun-ai/swiftide/commit/7d79b645d2e4f7da05b4c9952a1ceb79583572b3) Fixing some grammar typos on README.md (#51) by @hectorip ### New Contributors * @hectorip made their first contribution in [#51](https://github.com/bosun-ai/swiftide/pull/51) **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.3.0...swiftide-v0.3.1 ## [swiftide-v0.3.0](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.3.0) - 2024-06-14 ### New features - [745b8ed](https://github.com/bosun-ai/swiftide/commit/745b8ed7e58f76e415501e6219ecec65551d1897) *(ingestion_pipeline)* [**breaking**] Support chained storage backends (#46) by @timonv ````text Pipeline now supports multiple storage backends. This makes the order of adding storage important. Changed the name of the method to reflect that. 
```` **BREAKING CHANGE**: support chained storage backends (#46) - [cd055f1](https://github.com/bosun-ai/swiftide/commit/cd055f19096daa802fe7fc34763bfdfd87c1ec41) *(ingestion_pipeline)* Concurrency improvements (#48) by @timonv - [1f0cd28](https://github.com/bosun-ai/swiftide/commit/1f0cd28ce4c02a39dbab7dd3c3f789798644daa3) *(ingestion_pipeline)* Early return if any error encountered (#49) by @timonv - [fa74939](https://github.com/bosun-ai/swiftide/commit/fa74939b30bd31301e3f80c407f153b5d96aa007) Configurable concurrency for transformers and chunkers (#47) by @timonv ### Docs - [473e60e](https://github.com/bosun-ai/swiftide/commit/473e60ecf9356e2fcabe68245f8bb8be7373cdfb) Update linkedin link by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.2.1...swiftide-v0.3.0 ## [swiftide-v0.2.1](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.2.1) - 2024-06-13 ### Docs - [cb9b4fe](https://github.com/bosun-ai/swiftide/commit/cb9b4feec1c3654f5067f9478b1a7cf59040a9fe) Add link to bosun by @timonv - [e330ab9](https://github.com/bosun-ai/swiftide/commit/e330ab92d7e8d3f806280fa781f0e1b179d9b900) Fix documentation link by @timonv **Full Changelog**: https://github.com/bosun-ai/swiftide/compare/swiftide-v0.2.0...swiftide-v0.2.1 ## [swiftide-v0.2.0](https://github.com/bosun-ai/swiftide/releases/tag/swiftide-v0.2.0) - 2024-06-13 ### New features - [9ec93be](https://github.com/bosun-ai/swiftide/commit/9ec93be110bd047c7e276714c48df236b1a235d7) Api improvements with example (#10) by @timonv ### Bug fixes - [42f8008](https://github.com/bosun-ai/swiftide/commit/42f80086042c659aef74ddd0ea1463c84650938d) Clippy & fmt by @timonv - [5b7ffd7](https://github.com/bosun-ai/swiftide/commit/5b7ffd7368a2688f70892fe37f28c0baea7ad54f) Fmt by @timonv ### Docs - [95a6200](https://github.com/bosun-ai/swiftide/commit/95a62008be1869e581ecaa0586a48cfbb6a7606a) *(swiftide)* Documented file swiftide/src/ingestion/ingestion_pipeline.rs (#14) by @bosun-ai[bot] - 
[7abccc2](https://github.com/bosun-ai/swiftide/commit/7abccc2af890c8369a2b46940f35274080b3cb61) *(swiftide)* Documented file swiftide/src/ingestion/ingestion_stream.rs (#16) by @bosun-ai[bot] - [755cd47](https://github.com/bosun-ai/swiftide/commit/755cd47ad00e562818162cf78e6df0c5daa99d14) *(swiftide)* Documented file swiftide/src/ingestion/ingestion_node.rs (#15) by @bosun-ai[bot] - [2ea5a84](https://github.com/bosun-ai/swiftide/commit/2ea5a8445c8df7ef36e5fbc25f13c870e5a4dfd5) *(swiftide)* Documented file swiftide/src/integrations/openai/mod.rs (#21) by @bosun-ai[bot] - [b319c0d](https://github.com/bosun-ai/swiftide/commit/b319c0d484db65d3a4594347e70770b8fac39e10) *(swiftide)* Documented file swiftide/src/integrations/treesitter/splitter.rs (#30) by @bosun-ai[bot] - [29fce74](https://github.com/bosun-ai/swiftide/commit/29fce7437042f1f287987011825b57c58c180696) *(swiftide)* Documented file swiftide/src/integrations/redis/node_cache.rs (#29) by @bosun-ai[bot] - [7229af8](https://github.com/bosun-ai/swiftide/commit/7229af8535daa450ebafd6c45c322222a2dd12a0) *(swiftide)* Documented file swiftide/src/integrations/qdrant/persist.rs (#24) by @bosun-ai[bot] - [6240a26](https://github.com/bosun-ai/swiftide/commit/6240a260b582034970d2ee46da9f5234cf317820) *(swiftide)* Documented file swiftide/src/integrations/redis/mod.rs (#23) by @bosun-ai[bot] - [7688c99](https://github.com/bosun-ai/swiftide/commit/7688c993125a129204739fc7cd8d23d0ebfc9022) *(swiftide)* Documented file swiftide/src/integrations/qdrant/mod.rs (#22) by @bosun-ai[bot] - [d572c88](https://github.com/bosun-ai/swiftide/commit/d572c88f2b4cfc4bbdd7bd5ca93f7fd8460f1cb0) *(swiftide)* Documented file swiftide/src/integrations/qdrant/ingestion_node.rs (#20) by @bosun-ai[bot] - [14e24c3](https://github.com/bosun-ai/swiftide/commit/14e24c30d28dc6272a5eb8275e758a2a989d66be) *(swiftide)* Documented file swiftide/src/ingestion/mod.rs (#28) by @bosun-ai[bot] - 
[502939f](https://github.com/bosun-ai/swiftide/commit/502939fcb5f56b7549b97bb99d4d121bf030835f) *(swiftide)* Documented file swiftide/src/integrations/treesitter/supported_languages.rs (#26) by @bosun-ai[bot] - [a78e68e](https://github.com/bosun-ai/swiftide/commit/a78e68e347dc3791957eeaf0f0adc050aeac1741) *(swiftide)* Documented file swiftide/tests/ingestion_pipeline.rs (#41) by @bosun-ai[bot] - [289687e](https://github.com/bosun-ai/swiftide/commit/289687e1a6c0a9555a6cbecb24951522529f9e1a) *(swiftide)* Documented file swiftide/src/loaders/mod.rs (#40) by @bosun-ai[bot] - [ebd0a5d](https://github.com/bosun-ai/swiftide/commit/ebd0a5dda940c5ef8c2b795ee8ab56e468726869) *(swiftide)* Documented file swiftide/src/transformers/chunk_code.rs (#39) by @bosun-ai[bot] - [fb428d1](https://github.com/bosun-ai/swiftide/commit/fb428d1e250eded80d4edc8ccc0c9a9b840fc065) *(swiftide)* Documented file swiftide/src/transformers/metadata_qa_text.rs (#36) by @bosun-ai[bot] - [305a641](https://github.com/bosun-ai/swiftide/commit/305a64149f015539823d748915e42ad440a7b4b4) *(swiftide)* Documented file swiftide/src/transformers/openai_embed.rs (#35) by @bosun-ai[bot] - [c932897](https://github.com/bosun-ai/swiftide/commit/c93289740806d9283ba488dd640dad5e4339e07d) *(swiftide)* Documented file swiftide/src/transformers/metadata_qa_code.rs (#34) by @bosun-ai[bot] - [090ef1b](https://github.com/bosun-ai/swiftide/commit/090ef1b38684afca8dbcbfe31a8debc2328042e5) *(swiftide)* Documented file swiftide/src/integrations/openai/simple_prompt.rs (#19) by @bosun-ai[bot] - [7cfcc83](https://github.com/bosun-ai/swiftide/commit/7cfcc83eec29d8bed44172b497d4468b0b67d293) Update readme template links and fix template by @timonv - [a717f3d](https://github.com/bosun-ai/swiftide/commit/a717f3d5a68d9c79f9b8d85d8cb8979100dc3949) Template links should be underscores by @timonv ### New Contributors * @bosun-ai[bot] made their first contribution in [#19](https://github.com/bosun-ai/swiftide/pull/19) **Full Changelog**: 
https://github.com/bosun-ai/swiftide/compare/v0.1.0...swiftide-v0.2.0 ## [v0.1.0](https://github.com/bosun-ai/swiftide/releases/tag/v0.1.0) - 2024-06-13 ### New features - [2a6e503](https://github.com/bosun-ai/swiftide/commit/2a6e503e8abdab83ead7b8e62f39e222fa9f45d1) *(doc)* Setup basic readme (#5) by @timonv - [b8f9166](https://github.com/bosun-ai/swiftide/commit/b8f9166e1d5419cf0d2cc6b6f0b2378241850574) *(fluyt)* Significant tracing improvements (#368) by @timonv ````text * fix(fluyt): remove unnecessary cloning and unwraps * fix(fluyt): also set target correctly on manual spans * fix(fluyt): do not capture raw result * feat(fluyt): nicer tracing for ingestion pipeline * fix(fluyt): remove instrumentation on lazy methods * feat(fluyt): add useful metadata to the root span * fix(fluyt): fix dangling spans in ingestion pipeline * fix(fluyt): do not log codebase in rag utils ```` - [0986136](https://github.com/bosun-ai/swiftide/commit/098613622a7018318f2fffe0d51cd17822bf2313) *(fluyt/code_ops)* Add languages to chunker and range for chunk size (#334) by @timonv ````text * feat(fluyt/code_ops): add more treesitter languages * fix: clippy + fmt * feat(fluyt/code_ops): implement builder and support range * feat(fluyt/code_ops): implement range limits for code chunking * feat(fluyt/indexing): code chunking supports size ```` - [f10bc30](https://github.com/bosun-ai/swiftide/commit/f10bc304b0b2e28281c90e57b6613c274dc20727) *(ingestion_pipeline)* Default concurrency is the number of cpus (#6) by @timonv - [7453ddc](https://github.com/bosun-ai/swiftide/commit/7453ddc387feb17906ae851a17695f4c8232ee19) Replace databuoy with new ingestion pipeline (#322) by @timonv - [054b560](https://github.com/bosun-ai/swiftide/commit/054b560571b4a4398a551837536fb8fbff13c149) Fix build and add feature flags for all integrations by @timonv ### Bug fixes - [fdf4be3](https://github.com/bosun-ai/swiftide/commit/fdf4be3d0967229a9dd84f568b0697fea4ddd341) *(fluyt)* Ensure minimal tracing by @timonv 
- [389b0f1](https://github.com/bosun-ai/swiftide/commit/389b0f12039f29703bc8bb71919b8067fadf5a8e) Add debug info to qdrant setup by @timonv - [bb905a3](https://github.com/bosun-ai/swiftide/commit/bb905a30d871ea3b238c3bc5cfd1d96724c8d4eb) Use rustls on redis and log errors by @timonv - [458801c](https://github.com/bosun-ai/swiftide/commit/458801c16f9111c1070878c3a82a319701ae379c) Properly connect to redis over tls by @timonv ### Miscellaneous - [ce6e465](https://github.com/bosun-ai/swiftide/commit/ce6e465d4fb12e2bbc7547738b5fbe5133ec2d5a) *(fluyt)* Add verbose log on checking if index exists by @timonv - [6967b0d](https://github.com/bosun-ai/swiftide/commit/6967b0d5b6221f7620161969865fb31959fc93b8) Make indexing extraction compile by @tinco - [f595f3d](https://github.com/bosun-ai/swiftide/commit/f595f3dae88bb4da5f4bbf6c5fe4f04abb4b7db3) Add rust-toolchain on stable by @timonv - [da004c6](https://github.com/bosun-ai/swiftide/commit/da004c6fcf82579c3c75414cb9f04f02530e2e31) Start cleaning up dependencies by @timonv - [cccdaf5](https://github.com/bosun-ai/swiftide/commit/cccdaf567744d58e0ee8ffcc8636f3b35090778f) Remove more unused dependencies by @timonv - [7ee8799](https://github.com/bosun-ai/swiftide/commit/7ee8799aeccc56fb0c14dbe68a7126cabfb40dd3) Remove more crates and update by @timonv - [951f496](https://github.com/bosun-ai/swiftide/commit/951f496498b35f7687fb556e5bf7f931a662ff8a) Clean up more crates by @timonv - [1f17d84](https://github.com/bosun-ai/swiftide/commit/1f17d84cc218602a480b27974f23f64c4269134f) Cargo update by @timonv - [730d879](https://github.com/bosun-ai/swiftide/commit/730d879e76c867c2097aef83bbbfa1211a053bdc) Create LICENSE by @timonv - [44524fb](https://github.com/bosun-ai/swiftide/commit/44524fb51523291b9137fbdcaff9133a9a80c58a) Restructure repository and rename (#3) by @timonv ````text * chore: move traits around * chore: move crates to root folder * chore: restructure and make it compile * chore: remove infrastructure * fix: make it compile 
* fix: clippy * chore: remove min rust version * chore: cargo update * chore: remove code_ops * chore: settle on swiftide ```` - [e717b7f](https://github.com/bosun-ai/swiftide/commit/e717b7f0b1311b11ed4690e7e11d9fdf53d4a81b) Update issue templates by @timonv - [8e22e0e](https://github.com/bosun-ai/swiftide/commit/8e22e0ef82fffa4f907b0e2cccd1c4e010ffbd01) Cleanup by @timonv - [4d79d27](https://github.com/bosun-ai/swiftide/commit/4d79d27709e3fed32c1b1f2c1f8dbeae1721d714) Tests, tests, tests (#4) by @timonv - [1036d56](https://github.com/bosun-ai/swiftide/commit/1036d565d8d9740ab55995095d495e582ce643d8) Configure cargo toml (#7) by @timonv - [0ae98a7](https://github.com/bosun-ai/swiftide/commit/0ae98a772a751ddc60dd1d8e1606f9bdab4e04fd) Cleanup Cargo keywords by @timonv ### Refactor - [0d342ea](https://github.com/bosun-ai/swiftide/commit/0d342eab747bc5f44adaa5b6131c30c09b1172a2) Models as first class citizens (#318) by @timonv ````text * refactor: refactor common datastructures to /models * refactor: promote to first class citizens * fix: clippy * fix: remove duplication in http handler * fix: clippy * fix: fmt * feat: update for latest change * fix(fluyt/models): doctest ```` ================================================ FILE: CONTRIBUTING.md ================================================ # Contribution guidelines Swiftide is in a very early stage and we are aware that we do lack features for the wider community. Contributions are very welcome. :tada: Indexing and querying are performance sensitive tasks. Please make sure to consider allocations and performance when contributing. AI Generated code is welcome and not frowned upon. Please be genuine and think critically about what you add. For AI agents read the [AGENTS.md](AGENTS.md) for workspace layout, commands, and expectations tailored to agents. ## Feature requests and feedback We love them, please let us know what you would like. Use one of the templates provided. 
## Code design * Simple, thin wrappers with sane defaults * Provide a builder (derive_builder) for easy customization * Keep Rust complexity (Arc/Box/Lifetimes/Pinning ...) encapsulated and away from library users * Adhere to [Rust api naming](https://rust-lang.github.io/api-guidelines/naming.html) as much as possible ## Bug reports It happens, but we still love them. ## Submitting pull requests If you have a great idea, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". Don't forget to give the project a star! Thanks again! If you just want to contribute (bless you!), see [our issues](https://github.com/bosun-ai/swiftide/issues). 1. Fork the Project 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`) 3. Commit your Changes (`git commit -m 'feat: Add some AmazingFeature'`) 4. Push to the Branch (`git push origin feature/AmazingFeature`) 5. Open a Pull Request Make sure that: * Public functions are documented in code * Documentation is updated in the [user documentation](https://github.com/bosun-ai/swiftide-website) * Tests are added * Verified performance with benchmarks if applicable ================================================ FILE: Cargo.toml ================================================ cargo-features = ["edition2024"] [workspace] members = ["swiftide", "swiftide-*", "examples", "benchmarks"] default-members = ["swiftide", "swiftide-*"] resolver = "2" [workspace.package] version = "0.32.1" edition = "2024" license = "MIT" readme = "README.md" keywords = ["llm", "rag", "ai", "data", "openai"] description = "Fast, streaming indexing, query, and agentic LLM applications in Rust" categories = ["asynchronous"] repository = "https://github.com/bosun-ai/swiftide" homepage = "https://swiftide.rs" [profile.dev.package] insta.opt-level = 3 similar.opt-level = 3 [workspace.dependencies] anyhow = { version = "1.0", default-features = false } thiserror = { version = "2.0", 
default-features = false } async-trait = { version = "0.1", default-features = false } derive_builder = { version = "0.20", default-features = true } fs-err = { version = "3.1", default-features = false } futures-util = { version = "0.3", default-features = true } tokio = { version = "1.46", features = [ "rt-multi-thread", "time", ], default-features = false } tokio-stream = { version = "0.1", default-features = false, features = [ "time", ] } tokio-util = { version = "0.7", default-features = false } tracing = { version = "0.1", features = [ "log", "attributes", ], default-features = false } num_cpus = { version = "1.17", default-features = false } pin-project = { version = "1.1", default-features = false } itertools = { version = "0.14", default-features = true } serde = { version = "1.0", features = [ "derive", "std", ], default-features = false } serde_json = { version = "1.0", default-features = false, features = ["std"] } strum = { version = "0.28", default-features = false } strum_macros = { version = "0.28", default-features = false } lazy_static = { version = "1.5", default-features = false } chrono = { version = "0.4", default-features = false } indoc = { version = "2.0", default-features = false } regex = { version = "1.11", default-features = false } uuid = { version = "1.18", features = [ "v3", "v4", "serde", ], default-features = false } dyn-clone = { version = "1.0", default-features = false } convert_case = { version = "0.11", default-features = false } base64 = { version = "0.22", default-features = false, features = ["std"] } # Mcp rmcp = { version = "0.17", default-features = false, features = [ "base64", "macros", "server", ] } schemars = { version = "1.0", default-features = false } # Integrations spider = { version = "2.45", default-features = false } async-openai = { version = ">=0.33.0", default-features = false } qdrant-client = { version = "1.17", default-features = false, features = [ "serde", ] } fluvio = { version = "0.50.1", 
default-features = false } rdkafka = { version = "0.39.0", features = ["cmake-build"] } lancedb = { version = "0.26", default-features = false, features = ["remote"] } # Needs to stay in sync with whatever lancedb uses, nice arrow-array = { version = "57.1", default-features = false } parquet = { version = "57.1", default-features = false, features = ["async"] } redb = { version = "3.1", default-features = false } sqlx = { version = "0.8", features = [ "postgres", "uuid", ], default-features = false } aws-config = { version = "1.8", default-features = true } pgvector = { version = "0.4", features = ["sqlx"], default-features = false } aws-credential-types = { version = "1.2", default-features = false } aws-sdk-bedrockruntime = { version = "1.126", default-features = false } aws-smithy-types = { version = "1.3", default-features = false } criterion = { version = "0.8", default-features = false } darling = { version = "0.23", default-features = false } deadpool = { version = "0.13", default-features = false } document-features = { version = "0.2" } fastembed = { version = "5.5", default-features = false } flv-util = { version = "0.5", default-features = false } htmd = { version = "0.5", default-features = false } ignore = { version = "0.4", default-features = false } proc-macro2 = { version = "1.0", default-features = false } quote = { version = "1.0", default-features = false } redis = { version = "1.0", default-features = false } reqwest = { version = "0.13", default-features = false } secrecy = { version = "0.10", default-features = false } syn = { version = "2.0", default-features = false } tera = { version = "1.20", default-features = false } text-splitter = { version = "0.29", default-features = false } tracing-subscriber = { version = "0.3", default-features = true } tree-sitter = { version = "0.26", default-features = false, features = ["std"] } tree-sitter-java = { version = "0.23", default-features = false } tree-sitter-javascript = { version = "0.25", 
default-features = false } tree-sitter-python = { version = "0.25", default-features = false } tree-sitter-ruby = { version = "0.23", default-features = false } tree-sitter-rust = { version = "0.24", default-features = false } tree-sitter-typescript = { version = "0.23", default-features = false } tree-sitter-go = { version = "0.25", default-features = false } tree-sitter-solidity = { version = "1.2", default-features = false } tree-sitter-c = { version = "0.24", default-features = false } tree-sitter-cpp = { version = "0.23", default-features = false } tree-sitter-elixir = { version = "0.3.4", default-features = false } tree-sitter-html = { version = "0.23", default-features = false } tree-sitter-php = { version = "0.24", default-features = false } tree-sitter-c-sharp = { version = "0.23", default-features = false } async-anthropic = { version = "0.6.0", default-features = false } duckdb = { version = "1", default-features = false } libduckdb-sys = { version = "1", default-features = false } metrics = { version = "0.24", default-features = false } tiktoken-rs = { version = "0.9", default-features = false } reqwest-eventsource = { version = "0.6", default-features = false } # Testing test-log = { version = "0.2" } testcontainers = { version = "0.27", features = ["http_wait"] } testcontainers-modules = { version = "0.15" } mockall = { version = "0.14" } temp-dir = { version = "0.2" } wiremock = { version = "0.6" } test-case = { version = "3.3" } pretty_assertions = { version = "1.4" } insta = { version = "1.45", features = ["yaml", "filters"] } eventsource-stream = { version = "0.2" } [workspace.lints.rust] unsafe_code = "forbid" unexpected_cfgs = { level = "warn", check-cfg = [ 'cfg(coverage,coverage_nightly)', ] } [workspace.lints.clippy] cargo = { level = "warn", priority = -1 } pedantic = { level = "warn", priority = -1 } blocks_in_conditions = "allow" must_use_candidate = "allow" module_name_repetitions = "allow" missing_fields_in_debug = "allow" # Should be 
fixed asap multiple_crate_versions = "allow" ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Bosun.ai Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================
Table of Contents - [What is Swiftide?](#what-is-swiftide) - [High level features](#high-level-features) - [Latest updates on our blog :fire:](#latest-updates-on-our-blog-fire) - [Examples](#examples) - [Vision](#vision) - [Features](#features) - [In detail](#in-detail) - [Getting Started](#getting-started) - [Prerequisites](#prerequisites) - [Installation](#installation) - [Usage and concepts](#usage-and-concepts) - [Indexing](#indexing) - [Querying](#querying) - [Contributing](#contributing) - [Core Team Members](#core-team-members) - [License](#license)
![CI](https://img.shields.io/github/actions/workflow/status/bosun-ai/swiftide/test.yml?style=flat-square) ![Coverage Status](https://img.shields.io/coverallsCoverage/github/bosun-ai/swiftide?style=flat-square) [![Crate Badge]][Crate] [![Docs Badge]][API Docs] [![Contributors][contributors-shield]][contributors-url] [![Stargazers][stars-shield]][stars-url] ![Discord](https://img.shields.io/discord/1257672801553354802?style=flat-square&link=https%3A%2F%2Fdiscord.gg%2F3jjXYen9UY) [![MIT License][license-shield]][license-url] [![LinkedIn][linkedin-shield]][linkedin-url]
Logo

Swiftide

Fast, streaming indexing, query, and agentic LLM applications in Rust
Read more on swiftide.rs »

API Docs · Report Bug · Request Feature · Discord

(back to top)

## What is Swiftide? Swiftide is a Rust library for building LLM applications. From performing a simple prompt completion, to building fast, streaming indexing and querying pipelines, to building agents that can use tools and call other agents. ### High level features - Simple primitives for common LLM tasks - Build fast, streaming indexing and querying pipelines - Easily build agents, mix and match with previously built pipelines - A modular and extendable API, with minimal abstractions - Integrations with popular LLMs and storage providers - Ready to use pipeline transformations or bring your own - Build graph like workflows with Tasks - [Langfuse](https://langfuse.com) support
Swiftide overview
Part of the [bosun.ai](https://bosun.ai) project. An upcoming platform for autonomous code improvement. We <3 feedback: project ideas, suggestions, and complaints are very welcome. Feel free to open an issue or contact us on [discord](https://discord.gg/3jjXYen9UY). > [!CAUTION] > Swiftide is under heavy development and can have breaking changes. Documentation might fall short of all features, and, despite our efforts, may be slightly outdated. We recommend always keeping an eye on our [github](https://github.com/bosun-ai/swiftide) and [api documentation](https://docs.rs/swiftide/latest/swiftide/). If you find an issue or have any kind of feedback we'd love to hear from you.

(back to top)

## Latest updates on our blog :fire: - [Swiftide 0.31 - Tasks, Langfuse, Multi-Modal, and more](http://blog.bosun.ai/swiftide-0-31/) - [Swiftide 0.27 - Easy human-in-the-loop flows for agentic AI](http://blog.bosun.ai/swiftide-0-27/) - [Swiftide 0.26 - Streaming agents](http://blog.bosun.ai/swiftide-0-26/) - [Releasing kwaak with kwaak](https://bosun.ai/posts/releasing-kwaak-with-kwaak/) - [Swiftide 0.16 - AI Agents in Rust](https://bosun.ai/posts/swiftide-0-16/) - [Rust in LLM based tools for performance](https://bosun.ai/posts/rust-for-genai-performance/) - [Evaluate Swiftide pipelines with Ragas](https://bosun.ai/posts/evaluating-swiftide-with-ragas/) (2024-09-15) - [Release - Swiftide 0.12](https://bosun.ai/posts/swiftide-0-12/) (2024-09-13) - [Local code intel with Ollama, FastEmbed and OpenTelemetry](https://bosun.ai/posts/ollama-and-telemetry/) (2024-09-04) More on our [blog](https://blog.bosun.ai/)

(back to top)

## Examples Indexing a local code project, chunking into smaller pieces, enriching the nodes with metadata, and persisting into [Qdrant](https://qdrant.tech): ```rust indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"])) .with_default_llm_client(openai_client.clone()) .filter_cached(Redis::try_from_url( redis_url, "swiftide-examples", )?) .then_chunk(ChunkCode::try_for_language_and_chunk_size( "rust", 10..2048, )?) .then(MetadataQACode::default()) .then(move |node| my_own_thing(node)) .then_in_batch(Embed::new(openai_client.clone())) .then_store_with( Qdrant::builder() .batch_size(50) .vector_size(1536) .build()?, ) .run() .await?; ``` Querying for an example on how to use the query pipeline: ```rust query::Pipeline::default() .then_transform_query(GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(Embed::from_client( openai_client.clone(), )) .then_retrieve(qdrant.clone()) .then_answer(Simple::from_client(openai_client.clone())) .query("How can I use the query pipeline in Swiftide?") .await?; ``` Running an agent that can search code: ```rust #[swiftide::tool( description = "Searches code", param(name = "code_query", description = "The code query") )] async fn search_code( context: &dyn AgentContext, code_query: &str, ) -> Result { let command_output = context .executor() .exec_cmd(&Command::shell(format!("rg '{code_query}'"))) .await?; Ok(command_output.into()) } agents::Agent::builder() .llm(&openai) .tools(vec![search_code()]) .build()? .query("In what file can I find an example of a swiftide agent?") .await?; ``` Agents loop over LLM calls, tool calls, and lifecycle hooks until a final answer is reached. _You can find more detailed examples in [/examples](https://github.com/bosun-ai/swiftide/tree/master/examples)_

(back to top)

## Vision Our goal is to create a fast, extendable platform for building LLM applications in Rust, to further the development of automated AI applications, with an easy-to-use and easy-to-extend API.

(back to top)

## Features - Simple primitives for common LLM tasks - Fast, modular streaming indexing pipeline with async, parallel processing - Experimental query pipeline - Experimental agent framework - A variety of loaders, transformers, semantic chunkers, embedders, and more - Bring your own transformers by extending straightforward traits or use a closure - Splitting and merging pipelines - Jinja-like templating for prompts - Store into multiple backends - Integrations with OpenAI, Groq, Gemini, Anthropic, Redis, Qdrant, Ollama, FastEmbed-rs, Fluvio, LanceDB, and Treesitter - Evaluate pipelines with RAGAS - Sparse vector support for hybrid search - `tracing` supported for logging and tracing, see /examples and the `tracing` crate for more information. - Tracing layer for exporting to Langfuse ### In detail | **Feature** | **Details** | | -------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | **Supported Large Language Model providers** | OpenAI (and Azure)
Anthropic
Gemini
OpenRouter
AWS Bedrock (Converse API)
Groq - All models
Ollama - All models | | **Agents** | All the boilerplate for autonomous agents so you don't have to | | **Tasks** | Build graph-like workflows with tasks, combining all the above to build complex applications | | **Loading data** | Files
Scraping
Fluvio
Parquet
Kafka
Other pipelines and streams | | **Example and pre-built transformers and metadata generation** | Generate questions and answers for both text and code (Hyde)
Summaries, titles and queries via an LLM
Extract definitions and references with tree-sitter | | **Splitting and chunking** | Markdown
Text (text_splitter)
Code (with tree-sitter) | | **Storage** | Qdrant
Redis
LanceDB
Postgres
Duckdb | | **Query pipeline** | Similarity and hybrid search, query and response transformations, and evaluation |

(back to top)

## Getting Started ### Prerequisites Make sure you have the rust toolchain installed. [rustup](https://rustup.rs) is the recommended approach. To use OpenAI, an API key is required. Note that by default `async_openai` uses the `OPENAI_API_KEY` environment variable. Other integrations might have their own requirements. ### Installation 1. Set up a new Rust project 2. Add swiftide ```sh cargo add swiftide ``` 3. Enable the features of integrations you would like to use in your `Cargo.toml` 4. Write a pipeline (see our examples and documentation)

(back to top)

## Usage and concepts Before building your streams, you need to enable and configure any integrations required. See /examples. _We have a lot of examples, please refer to /examples and the [Documentation](https://docs.rs/swiftide/latest/swiftide/)_ > [!NOTE] > No integrations are enabled by default as some are code heavy. We recommend that you cherry-pick the integrations you need. By convention flags have the same name as the integration they represent. ### Indexing An indexing stream starts with a Loader that emits Nodes. For instance, with the Fileloader each file is a Node. You can then slice and dice, augment, and filter nodes. Each different kind of step in the pipeline requires different traits. This enables extension. Nodes are generic over their inner type. This is a transition in progress, but when you BYO, feel free to slice and dice. The inner type can change midway through the pipeline. - **from_loader** `(impl Loader)` starting point of the stream, creates and emits Nodes - **filter_cached** `(impl NodeCache)` filters cached nodes - **then** `(impl Transformer)` transforms the node and puts it on the stream - **then_in_batch** `(impl BatchTransformer)` transforms multiple nodes and puts them on the stream - **then_chunk** `(impl ChunkerTransformer)` transforms a single node and emits multiple nodes - **then_store_with** `(impl Storage)` stores the nodes in a storage backend, this can be chained Additionally, several generic transformers are implemented. They take implementers of `SimplePrompt` and `EmbedModel` to do their things. > [!WARNING] > Due to the performance, chunking before adding metadata gives rate limit errors on OpenAI very fast, especially with faster models like gpt-5-nano. Be aware. The `async-openai` crate provides an exponential backoff strategy. If that is still a problem, there is also a decorator that supports streaming in `swiftide_core/indexing_decorators`. ### Querying A query stream starts with a search strategy. 
In the query pipeline a `Query` goes through several stages. Transformers and retrievers work together to get the right context into a prompt, before generating an answer. Transformers and Retrievers operate on different stages of the Query via a generic state machine. Additionally, the search strategy is generic over the pipeline and Retrievers need to implement specifically for each strategy. That sounds like a lot, but tl;dr: the query pipeline is _fully and strongly typed_. - **Pending** The query has not been executed, and can be further transformed with transformers - **Retrieved** Documents have been retrieved, and can be further transformed to provide context for an answer - **Answered** The query is done Additionally, query pipelines can also be evaluated. I.e. by [Ragas](https://ragas.io). Similar to the indexing pipeline, each step is governed by simple Traits, and closures implement these traits as well.

(back to top)

## Contributing Swiftide is in a very early stage and we are aware that we lack features for the wider community. Contributions are very welcome. :tada: If you have a great idea, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". Don't forget to give the project a star! Thanks again! Indexing and querying are performance sensitive tasks. Please make sure to consider allocations and performance when contributing. AI Generated code is welcome and not frowned upon. Please be genuine and think critically about what you add. If you just want to contribute (bless you!), see [our issues](https://github.com/bosun-ai/swiftide/issues) or join us on Discord. 1. Fork the Project 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`) 3. Commit your Changes (`git commit -m 'feat: Add some AmazingFeature'`) 4. Push to the Branch (`git push origin feature/AmazingFeature`) 5. Open a Pull Request AI Agents can refer to [AGENTS.md](AGENTS.md) for workspace layout, commands, and expectations tailored to agents.

(back to top)

## Core Team Members

timonv
open for swiftide consulting

tinco

## License Distributed under the MIT License. See `LICENSE` for more information.

(back to top)

[contributors-shield]: https://img.shields.io/github/contributors/bosun-ai/swiftide.svg?style=flat-square [contributors-url]: https://github.com/bosun-ai/swiftide/graphs/contributors [stars-shield]: https://img.shields.io/github/stars/bosun-ai/swiftide.svg?style=flat-square [stars-url]: https://github.com/bosun-ai/swiftide/stargazers [license-shield]: https://img.shields.io/github/license/bosun-ai/swiftide.svg?style=flat-square [license-url]: https://github.com/bosun-ai/swiftide/blob/master/LICENSE.txt [linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=flat-square&logo=linkedin&colorB=555 [linkedin-url]: https://www.linkedin.com/company/bosun-ai [Crate Badge]: https://img.shields.io/crates/v/swiftide?logo=rust&style=flat-square&logoColor=E05D44&color=E05D44 [Crate]: https://crates.io/crates/swiftide [Docs Badge]: https://img.shields.io/docsrs/swiftide?logo=rust&style=flat-square&logoColor=E05D44 [API Docs]: https://docs.rs/swiftide ================================================ FILE: benchmarks/Cargo.toml ================================================ cargo-features = ["edition2024"] [package] name = "benchmarks" publish = false version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true [dev-dependencies] tokio = { workspace = true, features = ["full"] } swiftide = { path = "../swiftide", features = ["all", "redb"] } serde_json = { workspace = true } criterion = { workspace = true, features = ["html_reports", "async_tokio"] } anyhow = { workspace = true } futures-util = { workspace = true } testcontainers = { workspace = true, features = ["blocking"] } temp-dir = { workspace = true } [[bench]] name = "fileloader" path = "fileloader.rs" harness = false [[bench]] name = "index-readme-local" path = "local_pipeline.rs" harness = false [[bench]] name = "node-cache" path 
= "node_cache_comparison.rs" harness = false ================================================ FILE: benchmarks/fileloader.rs ================================================ use std::hint::black_box; use anyhow::Result; use criterion::{Criterion, criterion_group, criterion_main}; use futures_util::stream::{StreamExt, TryStreamExt}; use swiftide::traits::Loader; async fn run_fileloader(num_files: usize) -> Result { let mut total_nodes = 0; let mut stream = swiftide::indexing::loaders::FileLoader::new("./benchmarks/fileloader.rs") .with_extensions(&["rs"]) .into_stream() .take(num_files); while stream.try_next().await?.is_some() { total_nodes += 1; } assert!(total_nodes == num_files); Ok(total_nodes) } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("load_1", |b| b.iter(|| run_fileloader(black_box(1)))); c.bench_function("load_10", |b| b.iter(|| run_fileloader(black_box(10)))); } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); ================================================ FILE: benchmarks/local_pipeline.rs ================================================ use anyhow::Result; use criterion::{Criterion, criterion_group, criterion_main}; use swiftide::{ indexing::Pipeline, indexing::loaders::FileLoader, indexing::persist::MemoryStorage, indexing::transformers::{ChunkMarkdown, Embed}, integrations::fastembed::FastEmbed, }; async fn run_pipeline() -> Result<()> { Pipeline::from_loader(FileLoader::new("README.md").with_extensions(&["md"])) .then_chunk(ChunkMarkdown::from_chunk_range(20..256)) .then_in_batch(Embed::new(FastEmbed::builder().batch_size(10).build()?)) .then_store_with(MemoryStorage::default()) .run() .await } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("run_local_pipeline", |b| b.iter(run_pipeline)); } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); ================================================ FILE: benchmarks/node_cache_comparison.rs 
================================================ use anyhow::Result; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use swiftide::indexing::transformers::ChunkCode; use swiftide::{ indexing::{Pipeline, loaders::FileLoader, persist::MemoryStorage}, traits::NodeCache, }; use temp_dir::TempDir; use testcontainers::Container; use testcontainers::{ GenericImage, core::{IntoContainerPort, WaitFor}, runners::SyncRunner, }; async fn run_pipeline(node_cache: Box>) -> Result<()> { Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"])) .filter_cached(node_cache) .then_chunk(ChunkCode::try_for_language_and_chunk_size("rust", 10..256)?) .then_store_with(MemoryStorage::default()) .run() .await } fn criterion_benchmark(c: &mut Criterion) { let redis_container = start_redis(); let redis_url = format!( "redis://{host}:{port}", host = redis_container.get_host().unwrap(), port = redis_container.get_host_port_ipv4(6379).unwrap() ); let redis: Box> = Box::new( swiftide::integrations::redis::Redis::try_from_url(redis_url, "criterion").unwrap(), ); let tempdir = TempDir::new().unwrap(); let redb: Box> = Box::new( swiftide::integrations::redb::Redb::builder() .database_path(tempdir.child("criterion")) .build() .unwrap(), ); let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); for node_cache in [(redis, "redis"), (redb, "redb")] { c.bench_with_input( BenchmarkId::new("node_cache", node_cache.1), &node_cache, |b, s| { let cache_clone = s.0.clone(); runtime.spawn_blocking(move || async move { cache_clone.clear().await.unwrap() }); b.to_async(&runtime).iter(|| run_pipeline(s.0.clone())) }, ); } } fn start_redis() -> Container { GenericImage::new("redis", "7.2.4") .with_exposed_port(6379.tcp()) .with_wait_for(WaitFor::message_on_stdout("Ready to accept connections")) .start() .expect("Redis started") } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); 
================================================ FILE: benchmarks/output.txt ================================================ test load_1 ... bench: 6 ns/iter (+/- 0) test load_10 ... bench: 6 ns/iter (+/- 0) test run_local_pipeline ... bench: 846 ns/iter (+/- 7) ================================================ FILE: cliff.toml ================================================ [remote.github] owner = "bosun-ai" repo = "swiftide" [git] commit_parsers = [ { message = "(r|R)elease", skip = true }, { message = "^(feat|fix|perf|chore)\\(ci\\)", group = "Miscellaneous" }, { message = "^feat*", group = "New features" }, { message = "^fix*", group = "Bug fixes" }, { message = "^perf*", group = "Performance" }, { message = "^chore*", group = "Miscellaneous" }, ] [changelog] # changelog header header = """ # Changelog All notable changes to this project will be documented in this file. """ body = """ {%- if not version %} ## [unreleased] {% else -%} ## [{{ version }}](https://github.com/bosun-ai/swiftide/releases/tag/{{ version }}) - {{ timestamp | date(format="%Y-%m-%d") }} {% endif -%} {% macro commit(commit) -%} - [{{ commit.id | truncate(length=7, end="") }}]({{ "https://github.com/bosun-ai/swiftide/commit/" ~ commit.id }}) \ {% if commit.scope %}*({{commit.scope | default(value = "uncategorized") | lower }})* {% endif %}\ {%- if commit.breaking %} [**breaking**]{% endif %} \ {{ commit.message | upper_first | trim }}\ {% if commit.remote.username %} by @{{ commit.remote.username }}{%- endif -%}\ {%- if commit.links %} \ in {% for link in commit.links %}[{{link.text}}]({{link.href}}) {% endfor -%}\ {% endif %} {%- if commit.body and commit.remote.username and commit.remote.username is not containing("[bot]") %} ````text {#- 4 backticks escape any backticks in body #} {{commit.body | indent(prefix=" ") }} ```` {%- endif %} {%- if commit.breaking_description %} **BREAKING CHANGE**: {{ commit.breaking_description }} {%- endif %} {% endmacro -%} {% for group, commits in 
commits | group_by(attribute="group") %} ### {{ group | striptags | trim | upper_first }} {% for commit in commits | filter(attribute="scope") | sort(attribute="scope") %} {{ self::commit(commit=commit) }} {%- endfor -%} {% for commit in commits %} {%- if not commit.scope %} {{ self::commit(commit=commit) }} {%- endif -%} {%- endfor -%} {%- endfor %} {%- if github.contributors -%} {% if github.contributors | filter(attribute="is_first_time", value=true) | length != 0 %} ### New Contributors {%- endif %}\ {% for contributor in github.contributors | filter(attribute="is_first_time", value=true) %} * @{{ contributor.username }} made their first contribution {%- if contributor.pr_number %} in \ [#{{ contributor.pr_number }}]({{ self::remote_url() }}/pull/{{ contributor.pr_number }}) \ {%- endif %} {%- endfor -%} {% endif -%} {% if version %} {% if previous.version %} **Full Changelog**: {{ self::remote_url() }}/compare/{{ previous.version }}...{{ version }} {% endif %} {% else -%} {% raw %}\n{% endraw %} {% endif %} {%- macro remote_url() -%} {%- if remote.github -%} https://github.com/{{ remote.github.owner }}/{{ remote.github.repo }}\ {% else -%} https://github.com/bosun-ai/swiftide {%- endif -%} {% endmacro %} """ # template for the changelog body # https://keats.github.io/tera/docs/#introduction # note that the - before / after the % controls whether whitespace is rendered between each line. # Getting this right so that the markdown renders with the correct number of lines between headings # code fences and list items is pretty finicky. Note also that the 4 backticks in the commit macro # is intentional as this escapes any backticks in the commit body. 
# remove the leading and trailing whitespace from the template trim = false # changelog footer ================================================ FILE: deny.toml ================================================ [graph] all-features = true [licenses] confidence-threshold = 0.8 allow = [ "Apache-2.0", "BSD-2-Clause", "BSD-3-Clause", "ISC", "MIT", "Unicode-DFS-2016", "MPL-2.0", "Apache-2.0 WITH LLVM-exception", "Unlicense", "CC0-1.0", "zlib-acknowledgement", "Zlib", "0BSD", "Unicode-3.0", "NCSA", ] exceptions = [{ allow = ["OpenSSL"], crate = "ring" }] [advisories] version = 2 ignore = [ { id = "RUSTSEC-2023-0086", reason = "Ignore a security adivisory on lexical-core" }, { id = "RUSTSEC-2021-0141", reason = "Dotenv is used by spider" }, { id = "RUSTSEC-2024-0384", reason = "Instant is unmaintained" }, { id = "RUSTSEC-2024-0421", reason = "Older version of idna used by reqwest" }, ] [bans] multiple-versions = "allow" [sources] unknown-registry = "deny" unknown-git = "warn" allow-registry = ["https://github.com/rust-lang/crates.io-index"] [[licenses.clarify]] crate = "ring" # SPDX considers OpenSSL to encompass both the OpenSSL and SSLeay licenses # https://spdx.org/licenses/OpenSSL.html # ISC - Both BoringSSL and ring use this for their new files # MIT - "Files in third_party/ have their own licenses, as described therein. The MIT # license, for third_party/fiat, which, unlike other third_party directories, is # compiled into non-test libraries, is included below." 
# OpenSSL - Obviously expression = "ISC AND MIT AND OpenSSL" license-files = [{ path = "LICENSE", hash = 0xbd0eed23 }] ================================================ FILE: examples/Cargo.toml ================================================ cargo-features = ["edition2024"] [package] name = "swiftide-examples" publish = false version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true [dependencies] tokio = { workspace = true, features = ["full"] } swiftide = { path = "../swiftide/", features = [ "all", "scraping", "aws-bedrock", "groq", "ollama", "fluvio", "kafka", "lancedb", "pgvector", "swiftide-agents", "dashscope", "mcp", "anthropic", "gemini", "metrics", "langfuse", ] } swiftide-macros = { path = "../swiftide-macros" } tracing-subscriber = { workspace = true, features = ["fmt", "env-filter"] } serde_json = { workspace = true } spider = { workspace = true } fluvio = { workspace = true } temp-dir = { workspace = true } anyhow = { workspace = true } futures-util = { workspace = true } sqlx = { workspace = true } swiftide-test-utils = { path = "../swiftide-test-utils" } tracing = { workspace = true } serde = { workspace = true } rmcp = { workspace = true, features = [ "transport-child-process", "client", "server", ] } metrics = { workspace = true } schemars.workspace = true base64 = { workspace = true } [[example]] doc-scrape-examples = true name = "index-codebase" path = "index_codebase.rs" [[example]] name = "index-codebase-reduced-context" path = "index_codebase_reduced_context.rs" [[example]] doc-scrape-examples = true name = "fastembed" path = "fastembed.rs" [[example]] doc-scrape-examples = true name = "index-redis" path = "index_into_redis.rs" [[example]] doc-scrape-examples = true name = "index-markdown-metadata" path = "index_markdown_lots_of_metadata.rs" [[example]] 
doc-scrape-examples = true name = "scraping-index" path = "scraping_index_to_markdown.rs" [[example]] doc-scrape-examples = true name = "aws-bedrock" path = "aws_bedrock.rs" [[example]] name = "aws-bedrock-agent" path = "aws_bedrock_agent.rs" [[example]] doc-scrape-examples = true name = "store-multiple-vectors" path = "store_multiple_vectors.rs" [[example]] name = "index-groq" path = "index_groq.rs" [[example]] name = "index-ollama" path = "index_ollama.rs" [[example]] name = "query-pipeline" path = "query_pipeline.rs" [[example]] name = "hybrid-search" path = "hybrid_search.rs" [[example]] name = "fluvio" path = "fluvio.rs" [[example]] name = "kakfa" path = "kafka.rs" [[example]] name = "lancedb" path = "lancedb.rs" [[example]] name = "describe-image" path = "describe_image.rs" [[example]] name = "hello-agents" path = "hello_agents.rs" [[example]] name = "index-md-pgvector" path = "index_md_into_pgvector.rs" [[example]] name = "dashscope" path = "dashscope.rs" [[example]] name = "reranking" path = "reranking.rs" [[example]] name = "agents-mcp" path = "agents_mcp_tools.rs" [[example]] name = "agents-resume" path = "agents_resume.rs" [[example]] name = "streaming-agents" path = "streaming_agents.rs" [[example]] name = "agents-hitl" path = "agents_with_human_in_the_loop.rs" [[example]] name = "usage-metrics" path = "usage_metrics.rs" [[example]] name = "tasks" path = "tasks.rs" [[example]] name = "agent-can-fail-custom-schema" path = "agent_can_fail_custom_schema.rs" [[example]] name = "stop-with-args-custom-schema" path = "stop_with_args_custom_schema.rs" [[example]] name = "responses-api" path = "responses_api.rs" [[example]] name = "responses-api-reasoning" path = "responses_api_reasoning.rs" [[example]] name = "structured-prompt" path = "structured_prompt.rs" [[example]] name = "langfuse" path = "langfuse.rs" [[example]] name = "tool-custom-schema" path = "tool_custom_schema.rs" ================================================ FILE: 
examples/agent_can_fail_custom_schema.rs ================================================ //! Demonstrates how to replace the default failure arguments for `AgentCanFail` with a custom //! JSON schema and capture the structured failure payload when the agent stops. //! //! Set the `OPENAI_API_KEY` environment variable before running the example. The agent is guided //! to use the `task_failed` tool with the schema defined below whenever it cannot complete the //! task. use anyhow::Result; use schemars::{JsonSchema, Schema, schema_for}; use serde::{Deserialize, Serialize}; use serde_json::{self, to_string_pretty}; use swiftide::agents::tools::control::AgentCanFail; use swiftide::agents::{Agent, StopReason}; use swiftide::traits::Tool; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] enum FailureCategory { MissingDependency, PermissionDenied, UnexpectedRegression, } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] enum RemediationStatus { Planned, Blocked, Complete, } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[schemars(deny_unknown_fields)] struct FailureReport { category: FailureCategory, summary: String, impact: String, #[serde(default, skip_serializing_if = "Vec::is_empty")] recommended_actions: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] remediation_status: Option, } fn failure_schema() -> Schema { schema_for!(FailureReport) } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let schema = failure_schema(); let failure_tool = AgentCanFail::with_parameters_schema(schema.clone()); println!( "task_failed tool schema:\n{}", to_string_pretty(&failure_tool.tool_spec())?, ); let openai = swiftide::integrations::openai::OpenAI::builder() .default_prompt_model("gpt-4o-mini") .default_embed_model("text-embedding-3-small") .build()?; let mut builder = Agent::builder(); builder .llm(&openai) 
.tools([failure_tool.clone()]) .on_stop(|_, reason, _| { Box::pin(async move { if let StopReason::AgentFailed(Some(payload)) = reason { let json = to_string_pretty(&payload).unwrap(); println!("agent reported failure:\n{json}"); } Ok(()) }) }); if let Some(prompt) = builder.system_prompt_mut() { prompt .with_role("Incident response coordinator") .with_guidelines([ "If the task cannot be completed, call the `task_failed` tool using the provided JSON schema.", "Populate all required fields and list at least one `recommended_actions` entry.", "Clearly document the impact so downstream teams can prioritize remediation.", ]) .with_constraints(["Do not claim success when blockers remain unresolved."]); } let mut agent = builder.build()?; agent .query_once( "You must restore last night's database backup, but the only backup file is corrupted and no redundant copy exists. Report the failure.", ) .await?; Ok(()) } ================================================ FILE: examples/agents_mcp_tools.rs ================================================ //! This is an example of how to build a Swiftide agent with tools using the MCP protocol. //! //! The agent in this example prints all messages using a channel. 
use anyhow::Result; use rmcp::{ ServiceExt as _, model::{ClientInfo, Implementation}, transport::{ConfigureCommandExt as _, TokioChildProcess}, }; use swiftide::agents::{self, tools::mcp::McpToolbox}; #[tokio::main] async fn main() -> Result<()> { println!("Hello, agents!"); let openai = swiftide::integrations::openai::OpenAI::builder() .default_embed_model("text-embeddings-3-small") .default_prompt_model("gpt-4o-mini") .build()?; let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); tokio::spawn(async move { while let Some(msg) = rx.recv().await { println!("{msg}"); } }); // First set up our client info to identify ourselves to the server let client_info = ClientInfo { client_info: Implementation { name: "swiftide-example".into(), version: env!("CARGO_PKG_VERSION").into(), title: None, description: None, icons: None, website_url: None, }, ..Default::default() }; // Use `rmcp` to start the server let running_service = client_info .serve(TokioChildProcess::new( tokio::process::Command::new("npx").configure(|cmd| { cmd.args(["-y", "@modelcontextprotocol/server-everything"]); }), )?) .await?; // Create a toolbox from the running server, and only use the `add` tool // // A toolbox reveals it's tools to the swiftide agent the first time it starts (if the state of // the agent was pending). You can add as many toolboxes as you want. MCP services are an // implmenentation of a toolbox. A list of tools is another. let everything_toolbox = McpToolbox::from_running_service(running_service) .with_whitelist(["add"]) .to_owned(); agents::Agent::builder() .llm(&openai) // Add the toolbox to the agent .add_toolbox(everything_toolbox) // Every message added by the agent will be printed to stdout .on_new_message(move |_, msg| { let msg = msg.to_string(); let tx = tx.clone(); Box::pin(async move { tx.send(msg).unwrap(); Ok(()) }) }) .build()? 
.query("Use the add tool to add 1 and 2") .await?; Ok(()) } ================================================ FILE: examples/agents_resume.rs ================================================ //! This example illustrates how to resume an agent from existing messages. use anyhow::Result; use swiftide::agents::{self, DefaultContext}; #[tokio::main] async fn main() -> Result<()> { println!("Hello, agents!"); let openai = swiftide::integrations::openai::OpenAI::builder() .default_embed_model("text-embeddings-3-small") .default_prompt_model("gpt-4o-mini") .build()?; let mut first_agent = agents::Agent::builder().llm(&openai).build()?; first_agent.query("Say hello!").await?; // Let's store the messages in a database, retrieve them back, and start a new agent let stored_history = serde_json::to_string(&first_agent.history().await?)?; let retrieved_history: Vec<_> = serde_json::from_str(&stored_history)?; let restored_context = DefaultContext::default() .with_existing_messages(retrieved_history) .await? .to_owned(); let mut second_agent = agents::Agent::builder() .llm(&openai) .context(restored_context) // We'll use the one from the first agent, alternatively we could also pop it from the // previous history and add a new one here .no_system_prompt() .build()?; second_agent.query("What did you say?").await?; Ok(()) } ================================================ FILE: examples/agents_with_human_in_the_loop.rs ================================================ //! This is an example of using a human in the loop pattern with switfide agents. //! //! In the example we send the tool call over an channel, and then manually approve it. //! //! In a more realistic example, you can use other rust primitives to make it work for your //! usecase. I.e., make an api request with a callback url that will add the feedback. //! //! Both requesting feedback and providing feedback support an optional payload (as a //! `serde_json::Value`). //! //! 
This allows for more custom workflows, to either display or provide more input to the //! underlying tool call. //! //! For an example on how to implement your own custom wrappers, refer to //! `tools::control::ApprovalRequired` use anyhow::Result; use swiftide::{ agents::{self, StopReason, tools::control::ApprovalRequired}, chat_completion::{ToolOutput, errors::ToolError}, traits::{AgentContext, ToolFeedback}, }; use tracing_subscriber::EnvFilter; #[swiftide::tool( description = "Guess a number", param(name = "number", description = "Number to guess") )] async fn guess_a_number( _context: &dyn AgentContext, number: usize, ) -> Result { let actual_number = 42; if number == actual_number { Ok("You guessed it!".into()) } else { Ok("Try again!".into()) } } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt() .compact() .with_env_filter(EnvFilter::from_default_env()) .init(); println!("Hello, agents!"); let openai = swiftide::integrations::openai::OpenAI::builder() .default_prompt_model("gpt-4o") .build()?; // ApprovalRequired is a simple wrapper. You can also implement your own approval // flows by returning a `ToolOutput::FeedbackRequired` in a tool, // you can then use `has_received_feedback` and `received_feedback` on the context // to build your custom workflow. let guess_with_approval = ApprovalRequired::new(guess_a_number()); let mut agent = agents::Agent::builder() .llm(&openai) .tools(vec![guess_with_approval]) // Every message added by the agent will be printed to stdout .on_new_message(move |_, msg| { println!("{msg}"); Box::pin(async move { Ok(()) }) }) .limit(5) .build()?; // First query the agent, the agent will stop with a reason that feedback is required agent .query("Guess a number between 0 and 100 using the `guess_a_number` tool") .await?; // The agent stopped, lets get the tool call let Some(StopReason::FeedbackRequired { tool_call, .. 
}) = agent.stop_reason() else { panic!("expected a tool call to approve") }; // Alternatively, you can also get the stop reason from the agent state // agent.state().stop_reason().unwrap().feedback_required().unwrap() // Register that this tool call is ok. println!("Approving number guessing"); agent .context() .feedback_received(tool_call, &ToolFeedback::approved()) .await .unwrap(); // Run the agent again and it will pick up where it stopped. agent.run().await.unwrap(); Ok(()) } ================================================ FILE: examples/aws_bedrock.rs ================================================ //! # [Swiftide] Aws Bedrock example //! //! This example demonstrates how to use the `AwsBedrock` v2 integration to interact with Bedrock //! service. //! //! To use bedrock you will need the following: //! - AWS cli or environment variables configured //! - An aws region configured //! - Access to the bedrock models you want to use //! - A model id or arn //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples //! 
[AWS Bedrock documentation]: https://docs.aws.amazon.com/bedrock/ use swiftide::{ indexing, indexing::loaders::FileLoader, indexing::persist::MemoryStorage, indexing::transformers, integrations, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let aws_bedrock = integrations::aws_bedrock_v2::AwsBedrock::builder() .default_prompt_model("global.anthropic.claude-haiku-4-5-20251001-v1:0") .build()?; let memory_storage = MemoryStorage::default(); indexing::Pipeline::from_loader(FileLoader::new("./README.md")) .log_nodes() .then_chunk(transformers::ChunkMarkdown::from_chunk_range(100..512)) .then(transformers::MetadataSummary::new(aws_bedrock.clone())) .then_store_with(memory_storage.clone()) .log_all() .run() .await?; println!("Summaries:"); println!( "{}", memory_storage .get_all_values() .await .iter() .filter_map(|n| n.metadata.get("Summary").map(|v| v.to_string())) .collect::>() .join("\n---\n") ); Ok(()) } ================================================ FILE: examples/aws_bedrock_agent.rs ================================================ //! # [Swiftide] AWS Bedrock Agent Example //! //! This example demonstrates a simple agent setup with `AwsBedrock` v2. //! //! Requirements: //! - AWS credentials and region configured (CLI profile or environment variables) //! - Access to the Bedrock model you choose //! - A model with tool use support (the Claude model below supports this) //! //! [Swiftide]: https://github.com/bosun-ai/swiftide use anyhow::Result; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use swiftide::{ agents, chat_completion::{ToolOutput, errors::ToolError}, integrations::aws_bedrock_v2::AwsBedrock, traits::{AgentContext, Command}, }; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] struct FormatTimestampRequest { /// Prefix to prepend to the timestamp. prefix: String, /// Timestamp to format. 
timestamp: String, } #[swiftide::tool(description = "Get the current UTC date and time in RFC3339 format")] async fn current_utc_time(context: &dyn AgentContext) -> Result { let command_output = context .executor() .exec_cmd(&Command::shell("date -u +\"%Y-%m-%dT%H:%M:%SZ\"")) .await?; Ok(command_output.into()) } #[swiftide::tool( description = "Format a timestamp with a caller-provided prefix", param(name = "request", description = "Timestamp formatting input") )] async fn format_timestamp( _context: &dyn AgentContext, request: FormatTimestampRequest, ) -> Result { Ok(ToolOutput::text(format!( "{}{}", request.prefix, request.timestamp ))) } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let bedrock = AwsBedrock::builder() .default_prompt_model("global.anthropic.claude-sonnet-4-6") .build()?; let mut agent = agents::Agent::builder() .llm(&bedrock) .tools(vec![current_utc_time(), format_timestamp()]) .on_new_message(|_, msg| { let rendered = msg.to_string(); Box::pin(async move { println!("{rendered}"); Ok(()) }) }) .limit(6) .build()?; agent .query( "Call current_utc_time once. Then call format_timestamp with prefix \"UTC now: \" and \ that timestamp. 
After that, report the formatted result and stop.", ) .await?; Ok(()) } ================================================ FILE: examples/dashscope.rs ================================================ use swiftide::{ indexing::{ self, EmbeddedField, loaders::FileLoader, transformers::{ChunkMarkdown, Embed, MetadataQAText, metadata_qa_text}, }, integrations::{dashscope::DashscopeBuilder, lancedb::LanceDB}, query::{ self, answers::{self}, query_transformers::{self}, response_transformers, }, }; use temp_dir::TempDir; #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); let client = DashscopeBuilder::default() .default_embed_model("text-embedding-v2") .default_prompt_model("qwen-long") .build()?; let tempdir = TempDir::new().unwrap(); let lancedb = LanceDB::builder() .uri(tempdir.child("lancedb").to_str().unwrap()) .vector_size(1536) .with_vector(EmbeddedField::Combined) .with_metadata(metadata_qa_text::NAME) .table_name("swiftide_test") .build() .unwrap(); indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["md"])) .with_default_llm_client(client.clone()) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(client.clone())) .then_in_batch(Embed::new(client.clone()).with_batch_size(10)) .then_store_with(lancedb.clone()) .run() .await?; let pipeline = query::Pipeline::default() .then_transform_query(query_transformers::GenerateSubquestions::from_client( client.clone(), )) .then_transform_query(query_transformers::Embed::from_client(client.clone())) .then_retrieve(lancedb.clone()) .then_transform_response(response_transformers::Summary::from_client(client.clone())) .then_answer(answers::Simple::from_client(client.clone())); let result = pipeline .query("What is swiftide? 
Please provide an elaborate explanation") .await?; println!("===="); println!("{:?}", result.answer()); Ok(()) } ================================================ FILE: examples/describe_image.rs ================================================ //! Demonstrates passing an image to Chat Completions using a data URL. //! //! Set the `OPENAI_API_KEY` environment variable before running. use anyhow::{Context as _, Result}; use base64::{Engine as _, engine::general_purpose}; use swiftide::chat_completion::{ChatCompletionRequest, ChatMessage, ChatMessageContentPart}; use swiftide::traits::ChatCompletion; #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let openai = swiftide::integrations::openai::OpenAI::builder() .default_prompt_model("gpt-4o-mini") .build()?; let image_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("../images/logo.png"); let image_bytes = std::fs::read(&image_path).with_context(|| format!("Read {image_path:?}"))?; let encoded = general_purpose::STANDARD.encode(&image_bytes); let data_url = format!("data:image/png;base64,{encoded}"); let message = ChatMessage::new_user_with_parts(vec![ ChatMessageContentPart::text("Describe this image in one sentence."), ChatMessageContentPart::image(data_url), ]); let request = ChatCompletionRequest::builder() .messages(vec![message]) .build()?; let response = openai.complete(&request).await?; println!( "Image description: {}", response.message().unwrap_or("") ); Ok(()) } ================================================ FILE: examples/fastembed.rs ================================================ //! # [Swiftide] Indexing the Swiftide itself example //! //! This example demonstrates how to index the Swiftide codebase itself using FastEmbed. //! //! The pipeline will: //! - Load all `.rs` files from the current directory //! - Embed the chunks in batches of 10 using FastEmbed //! - Store the nodes in Qdrant //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! 
[examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing, indexing::loaders::FileLoader, indexing::transformers::Embed, integrations::{fastembed::FastEmbed, qdrant::Qdrant}, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let qdrant_url = std::env::var("QDRANT_URL") .as_deref() .unwrap_or("http://localhost:6334") .to_owned(); indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"])) .then_in_batch(Embed::new(FastEmbed::builder().batch_size(10).build()?)) .then_store_with( Qdrant::try_from_url(qdrant_url)? .batch_size(50) .vector_size(384) .collection_name("swiftide-examples-fastembed".to_string()) .build()?, ) .run() .await?; Ok(()) } ================================================ FILE: examples/fluvio.rs ================================================ //! # [Swiftide] Loading data from Fluvio //! //! This example demonstrates how to index the Swiftide codebase itself. //! Note that for it to work correctly you need to have OPENAI_API_KEY set, redis and qdrant //! running. //! //! The pipeline will: //! - Load all `.rs` files from the current directory //! - Skip any nodes previously processed; hashes are based on the path and chunk (not the //! metadata!) //! - Run metadata QA on each chunk; generating questions and answers and adding metadata //! - Chunk the code into pieces of 10 to 2048 bytes //! - Embed the chunks in batches of 10, Metadata is embedded by default //! - Store the nodes in Qdrant //! //! Note that metadata is copied over to smaller chunks when chunking. When making LLM requests //! with lots of small chunks, consider the rate limits of the API. //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! 
[examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing::{self, transformers::Embed}, integrations::{ fastembed::FastEmbed, fluvio::{ConsumerConfigExt, Fluvio}, qdrant::Qdrant, }, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); static TOPIC_NAME: &str = "hello-rust"; static PARTITION_NUM: u32 = 0; let loader = Fluvio::builder() .consumer_config_ext( ConsumerConfigExt::builder() .topic(TOPIC_NAME) .partition(PARTITION_NUM) .offset_start(fluvio::Offset::from_end(1)) .build() .unwrap(), ) .build() .unwrap(); indexing::Pipeline::from_loader(loader) .then_in_batch(Embed::new(FastEmbed::try_default().unwrap()).with_batch_size(10)) .then_store_with( Qdrant::builder() .batch_size(50) .vector_size(384) .collection_name("swiftide-examples") .build()?, ) .run() .await?; Ok(()) } ================================================ FILE: examples/hello_agents.rs ================================================ //! This is an example of how to build a Swiftide agent //! //! A swiftide agent runs completions in a loop, optionally with tools, to complete a task //! autonomously. Agents stop when either the LLM calls the always included `stop` tool, or //! (configurable) if the last message in the completion chain was from the assistant. //! //! Tools can be created by using the `tool` attribute macro as shown here. For more control (i.e. //! internal state), there //! is also a `Tool` derive macro for convenience. Anything that implements the `Tool` trait can //! act as a tool. //! //! Agents operate on an `AgentContext`, which is responsible for managaging the completion history //! and providing access to the outside world. For the latter, the context is expected to have a //! `ToolExecutor`, which by default runs locally. //! //! When building the agent, hooks are available to influence the state, completions, and general //! behaviour of the agent. Hooks are also traits. //! //! 
Refer to the api documentation for more detailed information. use anyhow::Result; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use swiftide::{ agents, chat_completion::{ToolOutput, errors::ToolError}, traits::{AgentContext, Command}, }; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] struct CodeSearchRequest { /// Search query to pass to ripgrep query: String, /// Optional repository root (defaults to the current working directory) repo: Option, /// Optional list of glob filters for the search file_globs: Option>, } #[swiftide::tool( description = "Searches code", param(name = "request", description = "Code search parameters") )] async fn search_code( context: &dyn AgentContext, request: CodeSearchRequest, ) -> Result { let repo = request.repo.as_deref().unwrap_or("."); let mut command = format!("cd {repo} && rg '{query}'", query = request.query); if let Some(globs) = &request.file_globs { for glob in globs { command.push_str(&format!(" -g '{glob}'")); } } let command_output = context .executor() .exec_cmd(&Command::shell(command)) .await?; Ok(command_output.into()) } const READ_FILE: &str = "Read a file"; #[swiftide::tool( description = READ_FILE, param(name = "path", description = "Path to the file") )] async fn read_file(context: &dyn AgentContext, path: &str) -> Result { let command_output = context .executor() .exec_cmd(&Command::shell(format!("cat {path}"))) .await?; Ok(command_output.into()) } // The macro understands common Rust types (strings, numbers, bools, vectors, maps, structs, etc.) // and automatically derives a JSON Schema via `schemars`. If you need to tweak the schema // manually, implement the `Tool` trait and attach your own `parameters_schema`. 
// #[swiftide::tool( description = "Guess a number", param(name = "number", description = "Number to guess") )] async fn guess_a_number( _context: &dyn AgentContext, number: usize, ) -> Result { let actual_number = 42; if number == actual_number { Ok("You guessed it!".into()) } else { Ok("Try again!".into()) } } #[tokio::main] async fn main() -> Result<()> { println!("Hello, agents!"); tracing_subscriber::fmt::init(); let openai = swiftide::integrations::gemini::Gemini::builder() .default_embed_model("gemini-embedding-exp-03-07") .default_prompt_model("gemini-2.0-flash") .build()?; let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); tokio::spawn(async move { while let Some(msg) = rx.recv().await { println!("{msg}"); } }); agents::Agent::builder() .llm(&openai) .tools(vec![search_code(), read_file(), guess_a_number()]) .before_all(move |_context| { // This is a hook that runs before any command is executed // No native async closures in Rust yet, so we have to use Box::pin Box::pin(async move { println!("Hello hook!"); Ok(()) }) }) // Every message added by the agent will be printed to stdout .on_new_message(move |_, msg| { let msg = msg.to_string(); let tx = tx.clone(); Box::pin(async move { tx.send(msg).unwrap(); Ok(()) }) }) .limit(5) .build()? .query("In what file can I find an example of a swiftide agent? When you are done guess a number and stop") .await?; Ok(()) } ================================================ FILE: examples/hybrid_search.rs ================================================ //! # [Swiftide] Hybrid search with qudrant //! //! This example demonstrates how to do hybrid search with Qdrant with Sparse vectors. //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! 
[examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing::{ self, EmbeddedField, loaders::FileLoader, transformers::{self, ChunkCode, MetadataQACode}, }, integrations::{fastembed::FastEmbed, openai, qdrant::Qdrant}, query::{self, answers, query_transformers, search_strategies::HybridSearch}, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); // Ensure all batching is consistent let batch_size = 64; let fastembed_sparse = FastEmbed::try_default_sparse().unwrap().to_owned(); let fastembed = FastEmbed::try_default().unwrap().to_owned(); // Set up openai with the mini model, which is great for indexing let openai = openai::OpenAI::builder() .default_prompt_model("gpt-4o-mini") .build() .unwrap(); // Set up qdrant and use the combined fields (metadata + chunks) for both sparse and dense // vectors let qdrant = Qdrant::builder() .batch_size(batch_size) .vector_size(384) .with_vector(EmbeddedField::Combined) .with_sparse_vector(EmbeddedField::Combined) .collection_name("swiftide-hybrid-example") .build()?; indexing::Pipeline::from_loader(FileLoader::new("swiftide-core/").with_extensions(&["rs"])) // Chunk fairly large as the context window is big .then_chunk(ChunkCode::try_for_language_and_chunk_size( "rust", 10..2048, )?) 
// Generate metadata on the code chunks to increase our chances of finding the right code .then(MetadataQACode::from_client(openai.clone()).build().unwrap()) .then_in_batch( transformers::SparseEmbed::new(fastembed_sparse.clone()).with_batch_size(batch_size), ) .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(batch_size)) .then_store_with(qdrant.clone()) .run() .await?; // Use sophisticated model for our query let openai = openai::OpenAI::builder() .default_prompt_model("gpt-4o") .build() .unwrap(); let query_pipeline = query::Pipeline::from_search_strategy( // Return a large amount of documents because we have a large context window // By default it uses the Combined fields, no need to configure HybridSearch::default() .with_top_n(20) .with_top_k(20) .to_owned(), ) // Generate subquestions on the initial query to increase our query coverage .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai.clone(), )) // Generate the same embeddings we used for indexing .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) .then_transform_query(query_transformers::SparseEmbed::from_client( fastembed_sparse.clone(), )) .then_retrieve(qdrant.clone()) // Answer with Simple, which either takes the documents as is (in this case), or any // transformations applied after querying .then_answer(answers::Simple::from_client(openai.clone())); let answer = query_pipeline .query("What are the different pipelines in Swiftide and how do they work? Provide an elaborate answer with examples.") .await .unwrap(); println!("{}", answer.answer()); // ## Different Pipelines in Swiftide and How They Work // // Swiftide offers multiple pipelines, notably the indexing pipeline and the query pipeline. The // functionality of these pipelines is enhanced using traits and components like transformers, // stream handlers, and more. 
Below we elaborate on the key components and how they become part // of the larger pipeline system: // // ### Indexing Pipeline // // 1. **Transformers**: // - **Transformer Trait**: Transforms single nodes into single nodes. Mainly used for // transforming data in a singular manner. // - **BatchableTransformer Trait**: Transforms a batch of nodes into a stream of nodes, // useful for bulk processing. // // ```rust // #[async_trait] // pub trait Transformer: Send + Sync { // async fn transform_node(&self, node: Node) -> Result; // fn concurrency(&self) -> Option { None } // } // // #[async_trait] // impl Transformer for F where F: Fn(Node) -> Result + Send + Sync { // async fn transform_node(&self, node: Node) -> Result { // self(node) // } // } // // #[async_trait] // pub trait BatchableTransformer: Send + Sync { // async fn batch_transform(&self, nodes: Vec) -> IndexingStream; // fn concurrency(&self) -> Option { None } // } // // #[async_trait] // impl BatchableTransformer for F where F: Fn(Vec) -> IndexingStream + Send + Sync // { async fn batch_transform(&self, nodes: Vec) -> IndexingStream { // self(nodes) // } // } // ``` // // 2. **Loaders**: // - Defines methods for converting a loader into an `IndexingStream`. // // ```rust // pub trait Loader { // fn into_stream(self) -> IndexingStream; // } // ``` // // 3. **Chunker Transformers**: // - Splits one node into multiple nodes. It's useful for breaking down large nodes into // smaller, manageable chunks. // // ```rust // #[async_trait] // pub trait ChunkerTransformer: Send + Sync + Debug { // async fn transform_node(&self, node: Node) -> IndexingStream; // fn concurrency(&self) -> Option { None } // } // ``` // // 4. **IndexingStream**: // - An asynchronous stream of nodes, used internally by the indexing pipeline to handle // streams of `Node` items. // // ```rust // pub struct IndexingStream { // #[pin] // pub(crate) inner: Pin> + Send>>, // } // ``` // // ### Query Pipeline // // 1. 
**QueryStream**: // - Handles query streams, ensuring data flows correctly through various query states. // // ```rust // pub struct QueryStream<'stream, Q: 'stream> { // #[pin] // pub(crate) inner: Pin>> + Send + 'stream>>, // #[pin] // pub sender: Option>>>, // } // ``` // // 2. **Query Handling**: // - Various state transitions and handling for queries in the pipeline. // // ```rust // pub struct Query { // original: String, // current: String, // state: State, // transformation_history: Vec, // pub embedding: Option, // pub sparse_embedding: Option, // } // ``` // // ### Extending the Pipeline with Traits // // Swiftide allows developers to extend the pipeline by implementing custom transformers, // loaders, and other components by implementing the respective traits. This design ensures // flexibility and modularity, allowing seamless integration of custom functionality. // // For example, to create a custom transformer: // ```rust // use crate::node::Node; // use anyhow::Result; // // struct MyCustomTransformer; // // #[async_trait] // impl Transformer for MyCustomTransformer { // async fn transform_node(&self, node: Node) -> Result { // // Custom transformation logic here... // Ok(node) // } // } // ``` // // ### Usage of Prompts in Transformers // // Swiftide utilizes the [`Template`] for templating prompts, making it easy to define and // manage prompts within transformers. // // ```rust // let template = PromptTemplate::try_compiled_from_str("hello {{world}}").await.unwrap(); // let prompt = template.to_prompt().with_context_value("world", "swiftide"); // assert_eq!(prompt.render().await.unwrap(), "hello swiftide"); // ``` // // ### Conclusion // // The Indexing and Query Pipelines in Swiftide are made extensible and modular via traits such // as `Transformer`, `BatchableTransformer`, `Loader`, and more. Custom implementations can // seamlessly integrate into the pipeline, providing flexibility in how data is processed, // transformed, and indexed. 
The use of prompts further enhances the capability to manage // dynamic and templated data within these pipelines. Ok(()) } ================================================ FILE: examples/index_codebase.rs ================================================ //! # [Swiftide] Indexing the Swiftide itself example //! //! This example demonstrates how to index the Swiftide codebase itself. //! Note that for it to work correctly you need to have OPENAI_API_KEY set, redis and qdrant //! running. //! //! The pipeline will: //! - Load all `.rs` files from the current directory //! - Skip any nodes previously processed; hashes are based on the path and chunk (not the //! metadata!) //! - Run metadata QA on each chunk; generating questions and answers and adding metadata //! - Chunk the code into pieces of 10 to 2048 bytes //! - Embed the chunks in batches of 10, Metadata is embedded by default //! - Store the nodes in Qdrant //! //! Note that metadata is copied over to smaller chunks when chunking. When making LLM requests //! with lots of small chunks, consider the rate limits of the API. //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing, indexing::LanguageModelWithBackOff, indexing::loaders::FileLoader, indexing::transformers::{ChunkCode, Embed, MetadataQACode}, integrations::{self, qdrant::Qdrant, redis::Redis}, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let openai_client = integrations::openai::OpenAI::builder() .default_embed_model("text-embedding-3-small") .default_prompt_model("gpt-3.5-turbo") .build()?; // Optionally use the backoff decorator to handle rate limits and transient errors. // // This works with streaming as well, async openai does not support this properly yet. 
let openai_client = LanguageModelWithBackOff::new(openai_client, Default::default()); let redis_url = std::env::var("REDIS_URL") .as_deref() .unwrap_or("redis://localhost:6379") .to_owned(); indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"])) .filter_cached(Redis::try_from_url(redis_url, "swiftide-examples")?) .then(MetadataQACode::new(openai_client.clone())) .then_chunk(ChunkCode::try_for_language_and_chunk_size( "rust", 10..2048, )?) .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with( Qdrant::builder() .batch_size(50) .vector_size(1536) .collection_name("swiftide-examples") .build()?, ) .run() .await?; Ok(()) } ================================================ FILE: examples/index_codebase_reduced_context.rs ================================================ //! # [Swiftide] Indexing the Swiftide itself example with reduced context size //! //! This example demonstrates how to index the Swiftide codebase itself, optimizing for a smaller //! context size. Note that for it to work correctly you need to have OPENAI_API_KEY set, redis and //! qdrant running. //! //! The pipeline will: //! - Load all `.rs` files from the current directory //! - Skip any nodes previously processed; hashes are based on the path and chunk (not the //! metadata!) //! - Generate an outline of the symbols defined in each file to be used as context in a later step //! and store it in the metadata //! - Chunk the code into pieces of 10 to 2048 bytes //! - For each chunk, generate a condensed subset of the symbols outline tailored for that specific //! chunk and store that in the metadata //! - Run metadata QA on each chunk; generating questions and answers and adding metadata //! - Embed the chunks in batches of 10, Metadata is embedded by default //! - Store the nodes in Qdrant //! //! Note that metadata is copied over to smaller chunks when chunking. When making LLM requests //! 
with lots of small chunks, consider the rate limits of the API. //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::indexing; use swiftide::indexing::loaders::FileLoader; use swiftide::indexing::transformers::{ChunkCode, Embed, MetadataQACode}; use swiftide::integrations::{self, qdrant::Qdrant, redis::Redis}; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let openai_client = integrations::openai::OpenAI::builder() .default_embed_model("text-embedding-3-small") .default_prompt_model("gpt-3.5-turbo") .build()?; let redis_url = std::env::var("REDIS_URL") .as_deref() .unwrap_or("redis://localhost:6379") .to_owned(); let chunk_size = 2048; indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"])) .filter_cached(Redis::try_from_url( redis_url, "swiftide-examples-codebase-reduced-context", )?) .then( indexing::transformers::OutlineCodeTreeSitter::try_for_language( "rust", Some(chunk_size), )?, ) .then(MetadataQACode::new(openai_client.clone())) .then_chunk(ChunkCode::try_for_language_and_chunk_size( "rust", 10..chunk_size, )?) .then(indexing::transformers::CompressCodeOutline::new( openai_client.clone(), )) .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with( Qdrant::builder() .batch_size(50) .vector_size(1536) .collection_name("swiftide-examples-codebase-reduced-context") .build()?, ) .run() .await?; Ok(()) } ================================================ FILE: examples/index_groq.rs ================================================ //! # [Swiftide] Indexing with Groq //! //! This example demonstrates how to index the Swiftide codebase itself. //! Note that for it to work correctly you need to have set the GROQ_API_KEY //! //! The pipeline will: //! - Loads the readme from the project //! - Chunk the code into pieces of 10 to 2048 bytes //! 
- Run metadata QA on each chunk with Groq; generating questions and answers and adding metadata //! - Embed the chunks in batches of 10, Metadata is embedded by default //! - Store the nodes in Memory Storage //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing, indexing::loaders::FileLoader, indexing::persist::MemoryStorage, indexing::transformers::{ChunkMarkdown, Embed, MetadataQAText}, integrations, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let groq_client = integrations::groq::Groq::builder() .default_prompt_model("llama3-8b-8192") .to_owned() .build()?; let fastembed = integrations::fastembed::FastEmbed::try_default()?; let memory_store = MemoryStorage::default(); indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(groq_client.clone())) .then_in_batch(Embed::new(fastembed).with_batch_size(10)) .then_store_with(memory_store.clone()) .run() .await?; println!("Example results:"); println!( "{}", memory_store .get_all_values() .await .into_iter() .flat_map(|n| n.metadata.into_values().map(|v| v.to_string())) .collect::>() .join("\n") ); Ok(()) } ================================================ FILE: examples/index_into_redis.rs ================================================ //! # [Swiftide] Indexing the Swiftide itself example //! //! This example demonstrates how to index the Swiftide codebase itself. //! Note that for it to work correctly you need to have OPENAI_API_KEY set, redis and qdrant //! running. //! //! The pipeline will: //! - Load all `.rs` files from the current directory //! - Skip any nodes previously processed; hashes are based on the path and chunk (not the //! metadata!) //! - Run metadata QA on each chunk; generating questions and answers and adding metadata //! 
- Chunk the code into pieces of 10 to 2048 bytes //! - Embed the chunks in batches of 10, Metadata is embedded by default //! - Store the nodes in Qdrant //! //! Note that metadata is copied over to smaller chunks when chunking. When making LLM requests //! with lots of small chunks, consider the rate limits of the API. //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing, indexing::loaders::FileLoader, indexing::transformers::ChunkCode, integrations::redis::Redis, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let redis_url = std::env::var("REDIS_URL") .as_deref() .unwrap_or("redis://localhost:6379") .to_owned(); indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"])) .then_chunk(ChunkCode::try_for_language_and_chunk_size( "rust", 10..2048, )?) .then_store_with( // By default the value is the full node serialized to JSON. // We can customize this by providing a custom function. Redis::try_build_from_url(&redis_url)? .persist_value_fn(|node| Ok(serde_json::to_string(&node.metadata)?)) .batch_size(50) .build()?, ) .run() .await?; Ok(()) } ================================================ FILE: examples/index_markdown_lots_of_metadata.rs ================================================ //! # [Swiftide] Indexing the Swiftide README with lots of metadata //! //! This example demonstrates how to index the Swiftide README with lots of metadata. //! //! The pipeline will: //! - Load the README.md file from the current directory //! - Chunk the file into pieces of 20 to 1024 bytes //! - Generate questions and answers for each chunk //! - Generate a summary for each chunk //! - Generate a title for each chunk //! - Generate keywords for each chunk //! - Embed each chunk //! - Store the nodes in Qdrant //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! 
[examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing, indexing::loaders::FileLoader, indexing::transformers::{ ChunkMarkdown, Embed, MetadataKeywords, MetadataQAText, MetadataSummary, MetadataTitle, }, integrations::{self, qdrant::Qdrant}, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let openai_client = integrations::openai::OpenAI::builder() .default_embed_model("text-embedding-3-small") .default_prompt_model("gpt-4o") .build()?; indexing::Pipeline::from_loader(FileLoader::new("README.md").with_extensions(&["md"])) .with_concurrency(1) .then_chunk(ChunkMarkdown::from_chunk_range(20..2048)) .then(MetadataQAText::new(openai_client.clone())) .then(MetadataSummary::new(openai_client.clone())) .then(MetadataTitle::new(openai_client.clone())) .then(MetadataKeywords::new(openai_client.clone())) .then_in_batch(Embed::new(openai_client.clone())) .log_all() .filter_errors() .then_store_with( Qdrant::builder() .batch_size(50) .vector_size(1536) .collection_name("swiftide-examples") .build()?, ) .run() .await?; Ok(()) } ================================================ FILE: examples/index_md_into_pgvector.rs ================================================ /// This example demonstrates how to index markdown into PGVector use std::path::PathBuf; use swiftide::{ indexing::{ self, EmbeddedField, loaders::FileLoader, transformers::{ ChunkMarkdown, Embed, MetadataQAText, metadata_qa_text::NAME as METADATA_QA_TEXT_NAME, }, }, integrations::{self, fastembed::FastEmbed, pgvector::PgVector}, query::{self, answers, query_transformers, response_transformers}, traits::SimplePrompt, }; async fn ask_query( llm_client: impl SimplePrompt + Clone + 'static, embed: FastEmbed, vector_store: PgVector, questions: Vec, ) -> Result, Box> { // By default the search strategy is SimilaritySingleEmbedding // which takes the latest query, embeds it, and does a similarity search // // Pgvector will return an error if 
multiple embeddings are set // // The pipeline generates subquestions to increase semantic coverage, embeds these in a single // embedding, retrieves the default top_k documents, summarizes them and uses that as context // for the final answer. let pipeline = query::Pipeline::default() .then_transform_query(query_transformers::GenerateSubquestions::from_client( llm_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client(embed)) .then_retrieve(vector_store.clone()) .then_transform_response(response_transformers::Summary::from_client( llm_client.clone(), )) .then_answer(answers::Simple::from_client(llm_client.clone())); let results: Vec = pipeline .query_all(questions) .await? .iter() .map(|result| result.answer().to_string()) .collect(); Ok(results) } #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); tracing::info!("Starting PgVector indexing test"); // Get the manifest directory path let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); // Create a PathBuf to test dataset from the manifest directory let test_dataset_path = PathBuf::from(manifest_dir).join("../README.md"); tracing::info!("Test Dataset path: {:?}", test_dataset_path); let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; tracing::info!("pgv_db_url :: {:#?}", pgv_db_url); let llm_client = integrations::ollama::Ollama::default() .with_default_prompt_model("llama3.2:latest") .to_owned(); let fastembed = integrations::fastembed::FastEmbed::try_default().expect("Could not create FastEmbed"); // Configure Pgvector with a default vector size, a single embedding // and in addition to embedding the text metadata, also store it in a field let pgv_storage = PgVector::builder() .db_url(pgv_db_url) .vector_size(384) .with_vector(EmbeddedField::Combined) .with_metadata(METADATA_QA_TEXT_NAME) .table_name("swiftide_pgvector_test".to_string()) .build() .unwrap(); // Drop the existing test 
table before running the test tracing::info!("Dropping existing test table & index if it exists"); let drop_table_sql = "DROP TABLE IF EXISTS swiftide_pgvector_test"; let drop_index_sql = "DROP INDEX IF EXISTS swiftide_pgvector_test_embedding_idx"; if let Ok(pool) = pgv_storage.get_pool().await { sqlx::query(drop_table_sql).execute(pool).await?; sqlx::query(drop_index_sql).execute(pool).await?; } else { return Err("Failed to get database connection pool".into()); } tracing::info!("Starting indexing pipeline"); indexing::Pipeline::from_loader(FileLoader::new(test_dataset_path).with_extensions(&["md"])) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(llm_client.clone())) .then_in_batch(Embed::new(fastembed.clone()).with_batch_size(100)) .then_store_with(pgv_storage.clone()) .run() .await?; tracing::info!("PgVector Indexing completed successfully"); let questions: Vec = vec![ "What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.".into(), "How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.".into(), ]; ask_query( llm_client.clone(), fastembed.clone(), pgv_storage.clone(), questions, ) .await? .iter() .enumerate() .for_each(|(i, result)| { tracing::info!("*** Answer Q{} ***", i + 1); tracing::info!("{}", result); tracing::info!("===X==="); }); tracing::info!("PgVector Indexing & retrieval test completed successfully"); Ok(()) } ================================================ FILE: examples/index_ollama.rs ================================================ //! # [Swiftide] Indexing with Ollama //! //! This example demonstrates how to index the Swiftide codebase itself. //! Note that for it to work correctly you need to have ollama running on the default local port. //! //! The pipeline will: //! - Loads the readme from the project //! - Chunk the code into pieces of 10 to 2048 bytes //! 
- Run metadata QA on each chunk with Ollama; generating questions and answers and adding //! metadata //! - Embed the chunks in batches of 10, Metadata is embedded by default //! - Store the nodes in Memory Storage //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing, indexing::loaders::FileLoader, indexing::persist::MemoryStorage, indexing::transformers::{ChunkMarkdown, Embed, MetadataQAText}, integrations, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let ollama_client = integrations::ollama::Ollama::default() .with_default_prompt_model("llama3.1") .to_owned(); let fastembed = integrations::fastembed::FastEmbed::try_default()?; let memory_store = MemoryStorage::default(); indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(ollama_client.clone())) .then_in_batch(Embed::new(fastembed).with_batch_size(10)) .then_store_with(memory_store.clone()) .run() .await?; println!("Example results:"); println!( "{}", memory_store .get_all_values() .await .into_iter() .flat_map(|n| n.metadata.into_values().map(|v| v.to_string())) .collect::>() .join("\n") ); Ok(()) } ================================================ FILE: examples/kafka.rs ================================================ //! # [Swiftide] Loading data from Kafka //! //! This example demonstrates how to index data from a Kafka topic and store the data in another //! Kafka topic. Note that for it to work correctly you need to have kafka. //! //! The pipeline will: //! - Load messages from a Kafka topic //! - Embed the chunks in batches of 10 //! - Store the nodes in kafka //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! 
[examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing::{self, transformers::Embed}, integrations::{ fastembed::FastEmbed, kafka::{ClientConfig, Kafka}, }, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); static LOADER_TOPIC: &str = "loader"; static STORAGE_TOPIC: &str = "storage"; let mut client_config = ClientConfig::new(); client_config.set("bootstrap.servers", "localhost:9092"); client_config.set("group.id", "group_id"); client_config.set("auto.offset.reset", "earliest"); let loader = Kafka::builder() .client_config(client_config.clone()) .topic(LOADER_TOPIC) .build() .unwrap(); let storage = Kafka::builder() .client_config(client_config) .topic(STORAGE_TOPIC) .create_topic_if_not_exists(true) .batch_size(2usize) .build() .unwrap(); indexing::Pipeline::from_loader(loader) .then_in_batch(Embed::new(FastEmbed::try_default().unwrap()).with_batch_size(10)) .then_store_with(storage) .run() .await?; Ok(()) } ================================================ FILE: examples/lancedb.rs ================================================ /// This example demonstrates how to use the LanceDB integration with Swiftide use swiftide::{ indexing::{ self, EmbeddedField, loaders::FileLoader, transformers::{ ChunkMarkdown, Embed, MetadataQAText, metadata_qa_text::NAME as METADATA_QA_TEXT_NAME, }, }, integrations::{self, lancedb::LanceDB}, query::{self, answers, query_transformers, response_transformers}, }; use temp_dir::TempDir; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let openai_client = integrations::openai::OpenAI::builder() .default_embed_model("text-embedding-3-small") .default_prompt_model("gpt-4o-mini") .build()?; let tempdir = TempDir::new().unwrap(); // Configure lancedb with a default vector size, a single embedding // and in addition to embedding the text metadata, also store it in a field let lancedb = LanceDB::builder() 
.uri(tempdir.child("lancedb").to_str().unwrap()) .vector_size(1536) .with_vector(EmbeddedField::Combined) .with_metadata(METADATA_QA_TEXT_NAME) .table_name("swiftide_test") .build() .unwrap(); indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(openai_client.clone())) .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with(lancedb.clone()) .run() .await?; // By default the search strategy is SimilaritySingleEmbedding // which takes the latest query, embeds it, and does a similarity search // // LanceDB will return an error if multiple embeddings are set // // The pipeline generates subquestions to increase semantic coverage, embeds these in a single // embedding, retrieves the default top_k documents, summarizes them and uses that as context // for the final answer. let pipeline = query::Pipeline::default() .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client( openai_client.clone(), )) .then_retrieve(lancedb.clone()) .then_transform_response(response_transformers::Summary::from_client( openai_client.clone(), )) .then_answer(answers::Simple::from_client(openai_client.clone())); let result = pipeline .query("What is swiftide? Please provide an elaborate explanation") .await?; println!("{:?}", result.answer()); Ok(()) } ================================================ FILE: examples/langfuse.rs ================================================ //! This is an example of using the langfuse integration with Swiftide. //! //! Langfuse is a platform for tracking and monitoring LLM usage and performance. //! //! When the feature `langfuse` is enabled, Swiftide can report tracing information, //! usage, inputs, and outputs to langfuse. //! //! For this to work, you need to set the LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY //! 
to the appropriate values. You can also set the LANGFUSE_URL environment variable //! to overwrite the default URL (http://localhost:3000). //! //! You can find more information about langfuse at https://langfuse.com/. On their github they //! also have a handy docker compose setup. //! //! More advanced usage is possible by using the `LangfuseLayer` directly. use anyhow::Result; use swiftide::traits::SimplePrompt; use tracing::level_filters::LevelFilter; use tracing_subscriber::{ EnvFilter, Layer as _, layer::SubscriberExt as _, util::SubscriberInitExt as _, }; #[tokio::main] async fn main() -> Result<()> { println!("Hello, langfuse!"); let fmt_layer = tracing_subscriber::fmt::layer() .compact() .with_target(false) .boxed(); let langfuse_layer = swiftide::langfuse::LangfuseLayer::default() .with_filter(LevelFilter::DEBUG) .boxed(); let registry = tracing_subscriber::registry() .with(EnvFilter::from_default_env()) .with(vec![fmt_layer, langfuse_layer]); registry.init(); prompt_openai().await?; Ok(()) } #[tracing::instrument] async fn prompt_openai() -> Result<()> { let openai = swiftide::integrations::openai::OpenAI::builder() .default_prompt_model("gpt-5") .build() .unwrap(); let paris = openai .prompt("What is the capital of France?".into()) .await?; println!("The capital of France is {paris}"); Ok(()) } ================================================ FILE: examples/query_pipeline.rs ================================================ use swiftide::{ indexing::{ self, loaders::FileLoader, transformers::{ChunkMarkdown, Embed, MetadataQAText}, }, integrations::{self, qdrant::Qdrant}, query::{self, answers, query_transformers, response_transformers}, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let openai_client = integrations::openai::OpenAI::builder() .default_embed_model("text-embedding-3-large") .default_prompt_model("gpt-4o") .build()?; let qdrant = Qdrant::builder() .batch_size(50) .vector_size(3072) 
.collection_name("swiftide-examples") .build()?; indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(openai_client.clone())) .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with(qdrant.clone()) .run() .await?; // By default the search strategy is SimilaritySingleEmbedding // which takes the latest query, embeds it, and does a similarity search let pipeline = query::Pipeline::default() .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client( openai_client.clone(), )) .then_retrieve(qdrant.clone()) .then_transform_response(response_transformers::Summary::from_client( openai_client.clone(), )) .then_answer(answers::Simple::from_client(openai_client.clone())); let result = pipeline .query("What is swiftide? Please provide an elaborate explanation") .await?; println!("{:?}", result.answer()); Ok(()) } ================================================ FILE: examples/reranking.rs ================================================ /// Demonstrates reranking retrieved documents with fastembed /// /// When reranking, many more documents are retrieved than used for the initial query. Maybe /// even from multiple sources. /// /// Reranking compares the relevancy of the documents with the initial query, then filters out /// the `top_k` documents. /// /// By default the model uses 'bge-reranker-base'. 
use swiftide::{ indexing::{ self, loaders::FileLoader, transformers::{ChunkMarkdown, Embed}, }, integrations::{self, fastembed, qdrant::Qdrant}, query::{self, answers, query_transformers}, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let openai_client = integrations::openai::OpenAI::builder() .default_prompt_model("gpt-4o") .build()?; let fastembed = fastembed::FastEmbed::builder().batch_size(10).build()?; let reranker = fastembed::Rerank::builder().top_k(5).build()?; let qdrant = Qdrant::builder() .batch_size(50) .vector_size(384) .collection_name("swiftide-reranking") .build()?; indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then_in_batch(Embed::new(fastembed.clone())) .then_store_with(qdrant.clone()) .run() .await?; // By default the search strategy is SimilaritySingleEmbedding // which takes the latest query, embeds it, and does a similarity search let pipeline = query::Pipeline::default() .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) .then_retrieve(qdrant.clone()) .then_transform_response(reranker) .then_answer(answers::Simple::from_client(openai_client.clone())); let result = pipeline .query("What is swiftide? 
Please provide an elaborate explanation") .await?; println!("{:?}", result.answer()); Ok(()) } ================================================ FILE: examples/responses_api.rs ================================================ use anyhow::{Context, Result}; use futures_util::StreamExt as _; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::io::Write as _; use swiftide::{ chat_completion::{ChatCompletionRequest, ChatMessage, ToolOutput, errors::ToolError}, integrations::openai::{OpenAI, Options}, traits::{AgentContext, ChatCompletion, SimplePrompt, StructuredPrompt}, }; use tracing_subscriber::EnvFilter; #[derive(Debug, Serialize, Deserialize, JsonSchema)] #[serde(deny_unknown_fields)] #[allow(dead_code)] struct WeatherSummary { description: String, } #[derive(Debug, Serialize, Deserialize, JsonSchema)] #[serde(deny_unknown_fields)] struct EchoArgs { message: String, } /// Minimal echo tool used to demonstrate tool calling with the Responses API. /// The macro implements the `Tool` trait, derives the JSON schema, and generates /// a helper constructor (`echo_tool()`) that returns a boxed tool ready for use. 
#[swiftide::tool( description = "Echos the provided message back to the caller.", param(name = "payload", description = "Text to echo back") )] async fn echo_tool( _context: &dyn AgentContext, payload: EchoArgs, ) -> Result { Ok(ToolOutput::text(format!("Echo: {}", payload.message))) } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .init(); let openai = OpenAI::builder() .default_prompt_model("gpt-4.1-mini") .default_options(Options::builder().temperature(0.2)) .use_responses_api(true) .build()?; let greeting = openai .prompt("Say hello in one short sentence".into()) .await?; println!("Prompt result: {greeting}"); let structured: WeatherSummary = openai .structured_prompt("Summarise today's weather in Amsterdam as JSON".into()) .await?; println!("Structured result: {structured:?}"); let chat_request = ChatCompletionRequest::builder() .messages(vec![ ChatMessage::new_system("You are a concise assistant."), ChatMessage::new_user("Share one fun fact about Amsterdam."), ]) .build()?; let completion = openai.complete(&chat_request).await?; println!( "Complete result: {}", completion.message().unwrap_or("") ); let mut stream = openai.complete_stream(&chat_request).await; print!("Streaming result: "); let mut streamed_message = String::new(); while let Some(chunk) = stream.next().await { let chunk = chunk?; if let Some(delta) = chunk .delta .as_ref() .and_then(|delta| delta.message_chunk.as_deref()) { print!("{delta}"); std::io::stdout().flush().ok(); } if let Some(message) = chunk.message() { streamed_message = message.to_string(); } } println!(); if streamed_message.is_empty() { println!("Full streamed result: "); } else { println!("Full streamed result: {streamed_message}"); } let tool_request = ChatCompletionRequest::builder() .messages(vec![ ChatMessage::new_system( "You are a precise assistant. 
Use available tools before replying directly.", ), ChatMessage::new_user( "Call the echo tool with the phrase \"Hello Responses API\" and then summarise the result.", ), ]) .tool(echo_tool()) .build()?; let tool_completion = openai.complete(&tool_request).await?; if let Some(tool_call) = tool_completion .tool_calls() .and_then(|calls| calls.first()) .cloned() { println!( "Assistant requested tool `{}` with arguments {}", tool_call.name(), tool_call.args().unwrap_or("") ); let args_json = tool_call .args() .context("echo tool call missing arguments")?; let args: EchoToolArgs = serde_json::from_str(args_json)?; let tool_output = format!("Echo: {}", args.payload.message); let mut follow_up_messages = tool_request.messages().to_vec(); follow_up_messages.push(ChatMessage::new_assistant( None::, Some(vec![tool_call.clone()]), )); follow_up_messages.push(ChatMessage::new_tool_output( tool_call.clone(), ToolOutput::text(tool_output), )); let follow_up_request = ChatCompletionRequest::builder() .messages(follow_up_messages) .tool(echo_tool()) .build()?; let final_completion = openai.complete(&follow_up_request).await?; println!( "Final response after tool call: {}", final_completion.message().unwrap_or("") ); } else { println!( "Assistant responded without tool calls: {}", tool_completion.message().unwrap_or("") ); } Ok(()) } ================================================ FILE: examples/responses_api_reasoning.rs ================================================ //! Simple agent example that enables reasoning summaries via the Responses API. 
use anyhow::Result; use swiftide::agents::Agent; use swiftide::chat_completion::{ChatMessage, ReasoningItem}; use swiftide::integrations::openai::{OpenAI, Options, ReasoningEffort}; use tracing_subscriber::EnvFilter; fn reasoning_summary(reasoning: Option<&[ReasoningItem]>) -> Option { let summary = reasoning .unwrap_or(&[]) .iter() .flat_map(|item| item.summary.iter()) .cloned() .collect::>() .join("\n"); if summary.is_empty() { None } else { Some(summary) } } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .init(); // Reasoning models require the Responses API. Enabling reasoning effort also asks for a // summary and encrypted reasoning content (enabled by default). If your OpenAI org is not // verified for reasoning access, summaries may be absent. Disable with // `reasoning_features(false)` if desired. let openai = OpenAI::builder() .default_prompt_model("o3-mini") .default_options(Options::builder().reasoning_effort(ReasoningEffort::Low)) .use_responses_api(true) .build()?; let mut agent = Agent::builder() .llm(&openai) .on_new_message(|_, message| { if let ChatMessage::Assistant(content, _) = message && let Some(content) = content.as_deref() { println!("Assistant: {content}"); } Box::pin(async move { Ok(()) }) }) .after_completion(|_, response| { if let Some(summary) = reasoning_summary(response.reasoning.as_deref()) { println!("Reasoning summary:\n{summary}"); } let has_encrypted = response .reasoning .as_ref() .is_some_and(|items| items.iter().any(|item| item.encrypted_content.is_some())); println!("Encrypted reasoning content present: {has_encrypted}"); Box::pin(async move { Ok(()) }) }) .build()?; agent .query("Explain why the sky is blue in one short paragraph.") .await?; Ok(()) } ================================================ FILE: examples/scraping_index_to_markdown.rs ================================================ //! 
# [Swiftide] Indexing the Swiftide README with lots of metadata //! //! This example demonstrates how to index the Swiftide README with lots of metadata. //! //! The pipeline will: //! - Scrape the Bosun website //! - Transform the html to markdown //! - Chunk the markdown into smaller pieces //! - Store the nodes in Memory //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use spider::website::Website; use swiftide::{ indexing, indexing::persist::MemoryStorage, indexing::transformers::ChunkMarkdown, integrations::scraping::{HtmlToMarkdownTransformer, ScrapingLoader}, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); indexing::Pipeline::from_loader(ScrapingLoader::from_spider( Website::new("https://www.bosun.ai/") .with_limit(1) .to_owned(), )) .then(HtmlToMarkdownTransformer::default()) .then_chunk(ChunkMarkdown::from_chunk_range(20..2048)) .log_all() .then_store_with(MemoryStorage::default()) .run() .await?; Ok(()) } ================================================ FILE: examples/stop_with_args_custom_schema.rs ================================================ //! Demonstrates how to plug a custom JSON schema into the stop tool for an OpenAI-powered agent. //! //! Set the `OPENAI_API_KEY` environment variable before running the example. The agent guides the //! model to call the `stop` tool with a structured payload that matches the schema defined below. //! The on-stop hook prints the structured payload that made the agent stop. 
use anyhow::Result; use schemars::{JsonSchema, Schema, schema_for}; use serde::{Deserialize, Serialize}; use serde_json::to_string_pretty; use swiftide::agents::tools::control::StopWithArgs; use swiftide::agents::{Agent, StopReason}; use swiftide::traits::Tool; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] enum TaskStatus { Succeeded, Failed, Cancelled, } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[schemars(deny_unknown_fields)] struct StopPayload { status: TaskStatus, summary: String, #[serde(default, skip_serializing_if = "Option::is_none")] details: Option, } fn stop_schema() -> Schema { schema_for!(StopPayload) } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let schema = stop_schema(); let stop_tool = StopWithArgs::with_parameters_schema(schema.clone()); println!( "stop tool schema:\n{}", to_string_pretty(&stop_tool.tool_spec())?, ); let openai = swiftide::integrations::openai::OpenAI::builder() .default_prompt_model("gpt-4o-mini") .default_embed_model("text-embedding-3-small") .build()?; let mut builder = Agent::builder(); builder .llm(&openai) .without_default_stop_tool() .tools([stop_tool.clone()]) .on_stop(|_, reason, _| { Box::pin(async move { if let StopReason::RequestedByTool(_, payload) = reason && let Some(payload) = payload { println!( "agent stopped with structured payload:\n{}", to_string_pretty(&payload).unwrap_or_else(|_| payload.to_string()), ); } Ok(()) }) }); if let Some(prompt) = builder.system_prompt_mut() { prompt .with_role("Workflow finisher") .with_guidelines([ "Summarize the work that was just completed and recommend next actions.", "When you are done, call the `stop` tool using the provided JSON schema.", "Always include the `details` field; use null when there is nothing to add.", ]) .with_constraints(["Never fabricate task status values outside the schema."]); } let mut agent = builder.build()?; agent .query_once( "You completed 
onboarding five merchants today. Prepare a final handoff report and stop.", ) .await?; Ok(()) } ================================================ FILE: examples/store_multiple_vectors.rs ================================================ //! # [Swiftide] Ingesting file with multiple metadata stored as named vectors //! //! This example demonstrates how to ingest a LICENSE file, generate multiple metadata, and store it //! all in Qdrant with individual named vectors //! //! The pipeline will: //! - Load the LICENSE file from the current directory //! - Chunk the file into pieces of 20 to 1024 bytes //! - Generate questions and answers for each chunk //! - Generate a summary for each chunk //! - Generate a title for each chunk //! - Generate keywords for each chunk //! - Embed each chunk //! - Embed each metadata //! - Store the nodes in Qdrant with chunk and metadata embeds as named vectors //! //! [Swiftide]: https://github.com/bosun-ai/swiftide //! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples use swiftide::{ indexing::loaders::FileLoader, indexing::transformers::{ ChunkMarkdown, Embed, MetadataKeywords, MetadataQAText, MetadataSummary, MetadataTitle, metadata_keywords, metadata_qa_text, metadata_summary, metadata_title, }, indexing::{self, EmbedMode, EmbeddedField}, integrations::{ self, qdrant::{Distance, Qdrant, VectorConfig}, }, }; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); let openai_client = integrations::openai::OpenAI::builder() .default_embed_model("text-embedding-3-small") .default_prompt_model("gpt-4o") .build()?; indexing::Pipeline::from_loader(FileLoader::new("LICENSE")) .with_concurrency(1) .with_embed_mode(EmbedMode::PerField) .then_chunk(ChunkMarkdown::from_chunk_range(20..2048)) .then(MetadataQAText::new(openai_client.clone())) .then(MetadataSummary::new(openai_client.clone())) .then(MetadataTitle::new(openai_client.clone())) .then(MetadataKeywords::new(openai_client.clone())) 
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
        .log_all()
        .filter_errors()
        .then_store_with(
            Qdrant::builder()
                .batch_size(50)
                .vector_size(1536)
                .collection_name("swiftide-multi-vectors")
                .with_vector(EmbeddedField::Chunk)
                .with_vector(EmbeddedField::Metadata(metadata_qa_text::NAME.into()))
                .with_vector(EmbeddedField::Metadata(metadata_summary::NAME.into()))
                .with_vector(
                    VectorConfig::builder()
                        .embedded_field(EmbeddedField::Metadata(metadata_title::NAME.into()))
                        .distance(Distance::Manhattan)
                        .build()?,
                )
                .with_vector(EmbeddedField::Metadata(metadata_keywords::NAME.into()))
                .build()?,
        )
        .run()
        .await?;

    Ok(())
}


================================================
FILE: examples/streaming_agents.rs
================================================
//! This example demonstrates how to stream responses from an agent
//!
//! By default, for convenience the accumulated response is streamed. You can opt-out of this
//! behaviour and only receive the delta as well (only with OpenAI-like providers).
use anyhow::Result;
use swiftide::agents;

#[tokio::main]
async fn main() -> Result<()> {
    let openai = swiftide::integrations::openai::OpenAI::builder()
        // NOTE: fixed model id; the OpenAI embedding model is `text-embedding-3-small`
        // (singular), not `text-embeddings-3-small`.
        .default_embed_model("text-embedding-3-small")
        .default_prompt_model("gpt-4o-mini")
        // Only streams the delta, leave this out to stream the full response
        .stream_full(false)
        .build()?;

    // let anthropic = swiftide::integrations::anthropic::Anthropic::builder()
    //     .default_prompt_model("claude-3-7-sonnet-latest")
    //     .build()?;

    agents::Agent::builder()
        .llm(&openai)
        .on_stream(|_agent, response| {
            // We print the message chunk if it exists. Streamed responses also include
            // the full response (without tool calls) in `message` and an `id` to map them to
            // previous chunks for convenience.
            //
            // The agent uses the full assembled response at the end of the stream.
            if let Some(delta) = &response.delta {
                print!(
                    "{}",
                    delta
                        .message_chunk
                        .as_deref()
                        .map(str::to_string)
                        .unwrap_or_default()
                );
            };

            // If `stream_full` is disabled, response.message() will be the accumulated response
            // response.message()
            Box::pin(async move { Ok(()) })
        })
        // Every message added by the agent will be printed to stdout
        .on_new_message(move |_, msg| {
            let msg = msg.to_string();
            Box::pin(async move {
                println!("\n---\nFinal message:\n {msg}");
                Ok(())
            })
        })
        .limit(5)
        .build()?
        .query("Why is the rust programming language so good?")
        .await?;

    Ok(())
}


================================================
FILE: examples/structured_prompt.rs
================================================
use anyhow::Result;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use swiftide::{integrations, traits::DynStructuredPrompt, traits::StructuredPrompt as _};

#[tokio::main]
async fn main() -> Result<()> {
    let client = integrations::openai::OpenAI::builder()
        .default_prompt_model("gpt-5-mini")
        .build()?;

    // Note that deny unknown fields is required. If you get an error on 'additionalProperties' to
    // be required, and false, this is what is missing.
    #[derive(Deserialize, JsonSchema, Serialize, Debug)]
    #[serde(deny_unknown_fields)]
    struct MyResponse {
        questions: Vec<String>,
    }

    let response = client
        .structured_prompt::<MyResponse>(
            "List three interesting questions about the Rust programming language.".into(),
        )
        .await?;

    println!("Response: {:?}", response.questions);

    // Because we use generics, structured_prompt is not dyn safe.
However, there is an // alternative: let client: Box = Box::new(client); let response: serde_json::Value = client .structured_prompt_dyn( "List three interesting questions about the Rust programming language.".into(), schemars::schema_for!(MyResponse), ) .await?; let parsed: MyResponse = serde_json::from_value(response)?; println!("Response: {:?}", parsed); Ok(()) } ================================================ FILE: examples/tasks.rs ================================================ //! This example illustrates how to set up a basic tasks //! //! Tasks follow a graph model where each output of a node must match the input of the next node. //! //! To set up a task, you register nodes that implement the `TaskNode` trait. Most swiftide //! primiteves implement this trait, including agents, prompts, and closures. //! //! Then each node can be connected to the next node using the `register_transition` method. There //! is also a `register_transition_async` method that allows you to register an async transition. //! //! Since running an autonomous agent in a task is subject to taste, there is a basic //! `TaskAgent` that wraps it in an `Arc`, but your own implementation might want to toy //! with the state instead of the task instead. //! //! The API for closures as task nodes is still a bit clunky and subject to change. 
use anyhow::Result; use swiftide::{ agents::{ self, tasks::{closures::SyncFn, impls::TaskAgent, task::Task}, }, prompt::Prompt, }; #[tokio::main] async fn main() -> Result<()> { println!("Hello, agents!"); let openai = swiftide::integrations::openai::OpenAI::builder() .default_embed_model("text-embeddings-3-small") .default_prompt_model("gpt-4o-mini") .build()?; let agent = agents::Agent::builder().llm(&openai).build()?; let mut task: Task = Task::new(); let agent_id = task.register_node(TaskAgent::from(agent)); let hello_id = task.register_node(SyncFn::new(move |_context: &()| { println!("Hello from a task!"); Ok(()) })); task.starts_with(agent_id); // Async is also supported task.register_transition_async(agent_id, move |context| { Box::pin(async move { hello_id.transitions_with(context) }) })?; task.register_transition(hello_id, task.transitions_to_done())?; task.run("Hello there!").await?; Ok(()) } ================================================ FILE: examples/tool_custom_schema.rs ================================================ use std::borrow::Cow; use anyhow::Result; use schemars::{JsonSchema, Schema, schema_for}; use serde::{Deserialize, Serialize}; use serde_json::Value; use swiftide::chat_completion::{Tool, ToolCall, ToolOutput, ToolSpec, errors::ToolError}; use swiftide::traits::AgentContext; #[derive(Clone)] struct WorkflowTool; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[schemars( title = "WorkflowInstruction", description = "Choose a workflow action and optional payload", deny_unknown_fields )] struct WorkflowInstruction { #[schemars(description = "Which workflow action to execute")] action: WorkflowAction, #[schemars(description = "Optional payload forwarded to the workflow engine")] #[serde(default, skip_serializing_if = "Option::is_none")] payload: Option, } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "lowercase")] enum WorkflowAction { Start, Stop, Status, } 
#[swiftide::reexports::async_trait::async_trait]
impl Tool for WorkflowTool {
    /// Stub invocation; a real tool would dispatch on the parsed
    /// `WorkflowInstruction` here.
    async fn invoke(
        &self,
        _agent_context: &dyn AgentContext,
        _tool_call: &ToolCall,
    ) -> Result<ToolOutput, ToolError> {
        Ok(ToolOutput::text(
            "Workflow execution not implemented in this example",
        ))
    }

    fn name<'tool>(&'tool self) -> Cow<'tool, str> {
        Cow::Borrowed("workflow_tool")
    }

    /// Advertises the custom JSON schema instead of a derived one.
    fn tool_spec(&self) -> ToolSpec {
        ToolSpec::builder()
            .name("workflow_tool")
            .description("Executes a workflow action with strict input choices")
            .parameters_schema(workflow_schema())
            .build()
            .expect("tool spec should be valid")
    }
}

/// Builds the schema handed to `ToolSpec::parameters_schema`.
fn workflow_schema() -> Schema {
    schema_for!(WorkflowInstruction)
}

fn main() -> Result<()> {
    let tool = WorkflowTool;
    let spec = tool.tool_spec();
    println!(
        "{}",
        serde_json::to_string_pretty(&spec).expect("tool spec should serialize"),
    );
    Ok(())
}


================================================
FILE: examples/usage_metrics.rs
================================================
//! Swiftide can emit usage metrics using `metrics-rs`.
//!
//! For metrics to be emitted, the `metrics` feature must be enabled.
//!
//! `metrics-rs` is a flexible Rust library that allows you to collect and publish metrics
//! anywhere. From the user side, you need to provide a recorder and handles. The library itself
//! provides several built-in for these, i.e. prometheus.
//!
//! In this example, we're indexing markdown and logging the usage metrics to stdout. For the
//! recording we're using the examples from metric-rs.
//!
//! Usage metrics are emitted for embedding, prompt requests, and chat completions. They always include
the model used as metadata use swiftide::{ indexing::{ self, loaders::FileLoader, transformers::{ChunkMarkdown, Embed}, }, integrations::{self, qdrant::Qdrant}, }; #[tokio::main] async fn main() -> Result<(), Box> { // tracing_subscriber::fmt::init(); init_print_logger(); let metric_metadata = HashMap::from([("example".to_string(), "metadata".to_string())]); let openai_client = integrations::openai::OpenAI::builder() .default_embed_model("text-embedding-3-small") .default_prompt_model("gpt-4.1-nano") // Metadata will be added to every metric .metric_metadata(metric_metadata) .build()?; indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["md"])) .then_chunk(ChunkMarkdown::from_chunk_range(10..512)) .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with( Qdrant::builder() .batch_size(50) .vector_size(1536) .collection_name("swiftide-examples-metrics") .build()?, ) .run() .await?; Ok(()) // (counter) registered key swiftide.usage.prompt_tokens with unit None and description "token // usage for the prompt" (counter) registered key swiftide.usage.completion_tokens with unit // None and description "token usage for the completion" (counter) registered key // swiftide.usage.total_tokens with unit None and description "total token usage" // counter increment for 'Key(swiftide.usage.prompt_tokens, [example = metadata, model = // text-embedding-3-small])': 356 counter increment for // 'Key(swiftide.usage.completion_tokens, [example = metadata, model = // text-embedding-3-small])': 0 counter increment for 'Key(swiftide.usage.total_tokens, // [example = metadata, model = text-embedding-3-small])': 356 counter increment for // 'Key(swiftide.usage.prompt_tokens, [example = metadata, model = text-embedding-3-small])': // 336 counter increment for 'Key(swiftide.usage.completion_tokens, [example = metadata, // model = text-embedding-3-small])': 0 counter increment for // 'Key(swiftide.usage.total_tokens, [example = metadata, model 
= text-embedding-3-small])': 336 // counter increment for 'Key(swiftide.usage.prompt_tokens, [example = metadata, model = // text-embedding-3-small])': 251 counter increment for // 'Key(swiftide.usage.completion_tokens, [example = metadata, model = // text-embedding-3-small])': 0 counter increment for 'Key(swiftide.usage.total_tokens, // [example = metadata, model = text-embedding-3-small])': 251 counter increment for // 'Key(swiftide.usage.prompt_tokens, [example = metadata, model = text-embedding-3-small])': // 404 counter increment for 'Key(swiftide.usage.completion_tokens, [example = metadata, // model = text-embedding-3-small])': 0 counter increment for // 'Key(swiftide.usage.total_tokens, [example = metadata, model = text-embedding-3-small])': 404 // counter increment for 'Key(swiftide.usage.prompt_tokens, [example = metadata, model = // text-embedding-3-small])': 329 counter increment for // 'Key(swiftide.usage.completion_tokens, [example = metadata, model = // text-embedding-3-small])': 0 counter increment for 'Key(swiftide.usage.total_tokens, // [example = metadata, model = text-embedding-3-small])': 329 } // --- Copied from https://github.com/metrics-rs/metrics/blob/main/metrics/examples/basic.rs use std::{collections::HashMap, sync::Arc}; use metrics::{Counter, CounterFn, Gauge, GaugeFn, Histogram, HistogramFn, Key, Recorder, Unit}; use metrics::{KeyName, Metadata, SharedString}; #[derive(Clone, Debug)] struct PrintHandle(Key); impl CounterFn for PrintHandle { fn increment(&self, value: u64) { println!("counter increment for '{}': {}", self.0, value); } fn absolute(&self, value: u64) { println!("counter absolute for '{}': {}", self.0, value); } } impl GaugeFn for PrintHandle { fn increment(&self, value: f64) { println!("gauge increment for '{}': {}", self.0, value); } fn decrement(&self, value: f64) { println!("gauge decrement for '{}': {}", self.0, value); } fn set(&self, value: f64) { println!("gauge set for '{}': {}", self.0, value); } } impl 
HistogramFn for PrintHandle { fn record(&self, value: f64) { println!("histogram record for '{}': {}", self.0, value); } } #[derive(Debug)] struct PrintRecorder; impl Recorder for PrintRecorder { fn describe_counter(&self, key_name: KeyName, unit: Option, description: SharedString) { println!( "(counter) registered key {} with unit {:?} and description {:?}", key_name.as_str(), unit, description ); } fn describe_gauge(&self, key_name: KeyName, unit: Option, description: SharedString) { println!( "(gauge) registered key {} with unit {:?} and description {:?}", key_name.as_str(), unit, description ); } fn describe_histogram(&self, key_name: KeyName, unit: Option, description: SharedString) { println!( "(histogram) registered key {} with unit {:?} and description {:?}", key_name.as_str(), unit, description ); } fn register_counter(&self, key: &Key, _metadata: &Metadata<'_>) -> Counter { Counter::from_arc(Arc::new(PrintHandle(key.clone()))) } fn register_gauge(&self, key: &Key, _metadata: &Metadata<'_>) -> Gauge { Gauge::from_arc(Arc::new(PrintHandle(key.clone()))) } fn register_histogram(&self, key: &Key, _metadata: &Metadata<'_>) -> Histogram { Histogram::from_arc(Arc::new(PrintHandle(key.clone()))) } } fn init_print_logger() { metrics::set_global_recorder(PrintRecorder).unwrap() } ================================================ FILE: release-plz.toml ================================================ [workspace] changelog_path = "./CHANGELOG.md" #changelog_config = "cliff.toml" git_tag_name = "v{{ version }}" changelog_update = false git_tag_enable = false git_release_enable = false repo_url = "https://github.com/bosun-ai/swiftide" [[package]] name = "swiftide-macros" publish_no_verify = true [[package]] # Only release the main package on github name = "swiftide" git_tag_name = "v{{ version }}" git_tag_enable = true git_release_enable = true changelog_include = [ "swiftide-core", "swiftide-indexing", "swiftide-integrations", "swiftide-query", "swiftide-test-utils", 
"swiftide-agents", "swiftide-macros", ] changelog_update = true [changelog] commit_parsers = [ { message = "^feat*", group = "New features" }, { message = "^fix*", group = "Bug fixes" }, { message = "^perf*", group = "Performance" }, { message = "^chore*", group = "Miscellaneous" }, { message = "^refactor*", group = "Miscellaneous" }, ] # changelog header header = """ # Changelog All notable changes to this project will be documented in this file. """ body = """ {%- if not version %} ## [unreleased] {% else -%} ## [{{ version }}]({{ release_link }}) - {{ timestamp | date(format="%Y-%m-%d") }} {% endif -%} {% macro commit(commit) -%} {% if commit.id -%} - [{{ commit.id | truncate(length=7, end="") }}]({{ "https://github.com/bosun-ai/swiftide/commit/" ~ commit.id }}) \ {% endif -%} {% if commit.scope %}*({{commit.scope | default(value = "uncategorized") | lower }})* {% endif %}\ {%- if commit.breaking %} [**breaking**]{% endif %} \ {{ commit.message | upper_first | trim }}\ {%- if commit.links %} \ in {% for link in commit.links %}[{{link.text}}]({{link.href}}) {% endfor -%}\ {% endif %} {%- if commit.breaking_description %} **BREAKING CHANGE**: {{ commit.breaking_description }} {%- endif %} {% endmacro -%} {% for group, commits in commits | group_by(attribute="group") %} ### {{ group | striptags | trim | upper_first }} {% for commit in commits | filter(attribute="scope") | sort(attribute="scope") %} {{ self::commit(commit=commit) }} {%- endfor -%} {% for commit in commits %} {%- if not commit.scope %} {{ self::commit(commit=commit) }} {%- endif -%} {%- endfor -%} {%- endfor %} {%- if github.contributors -%} {% if github.contributors | filter(attribute="is_first_time", value=true) | length != 0 %} ### New Contributors {%- endif %}\ {% for contributor in github.contributors | filter(attribute="is_first_time", value=true) %} * @{{ contributor.username }} made their first contribution {%- if contributor.pr_number %} in \ [#{{ contributor.pr_number }}]({{ 
self::remote_url() }}/pull/{{ contributor.pr_number }}) \ {%- endif %} {%- endfor -%} {% endif -%} {% if version %} {% if previous.version %} **Full Changelog**: {{ self::remote_url() }}/compare/{{ previous.version }}...{{ version }} {% endif %} {% else -%} {% raw %}\n{% endraw %} {% endif %} {%- macro remote_url() -%} {%- if remote.github -%} https://github.com/{{ remote.github.owner }}/{{ remote.github.repo }}\ {% else -%} https://github.com/bosun-ai/swiftide {%- endif -%} {% endmacro %} """ # template for the changelog body # https://keats.github.io/tera/docs/#introduction # note that the - before / after the % controls whether whitespace is rendered between each line. # Getting this right so that the markdown renders with the correct number of lines between headings # code fences and list items is pretty finicky. Note also that the 4 backticks in the commit macro # is intentional as this escapes any backticks in the commit body. # remove the leading and trailing whitespace from the template trim = false # changelog footer ================================================ FILE: renovate.json ================================================ { "$schema": "https://docs.renovatebot.com/renovate-schema.json", "extends": [ "config:recommended" ] } ================================================ FILE: rust-toolchain.toml ================================================ [toolchain] channel = "stable" ================================================ FILE: rustfmt.toml ================================================ # docs: https://rust-lang.github.io/rustfmt/ # Unstable options - to run these, use `cargo +nightly fmt` wrap_comments = true comment_width = 100 normalize_comments = true ================================================ FILE: swiftide/Cargo.toml ================================================ cargo-features = ["edition2024"] [package] name = "swiftide" version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true 
keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true include = [ "build.rs", "../README.md", "../images/", "src/", "../examples", "tests/", ] [badges] [dependencies] document-features = { workspace = true } # Local dependencies swiftide-core = { path = "../swiftide-core", version = "0.32" } swiftide-integrations = { path = "../swiftide-integrations", version = "0.32" } swiftide-indexing = { path = "../swiftide-indexing", version = "0.32" } swiftide-query = { path = "../swiftide-query", version = "0.32" } swiftide-agents = { path = "../swiftide-agents", version = "0.32", optional = true } swiftide-macros = { path = "../swiftide-macros", version = "0.32", optional = true } swiftide-langfuse = { path = "../swiftide-langfuse", version = "0.32", optional = true } # Re-exports for macros and ease of use anyhow.workspace = true async-trait.workspace = true serde.workspace = true serde_json.workspace = true schemars = { workspace = true, features = ["derive"] } [features] ## By default only macros are enabled default = ["macros"] macros = ["dep:swiftide-macros"] all = [ "qdrant", "redis", "tree-sitter", "openai", "fastembed", "scraping", "aws-bedrock", "groq", "ollama", "pgvector", ] #! 
### Integrations

## Enables Qdrant for storage and retrieval
qdrant = ["swiftide-integrations/qdrant", "swiftide-core/qdrant"]
## Enables PgVector for storage and retrieval
pgvector = ["swiftide-integrations/pgvector"]
## Enables Redis as an indexing cache and storage
redis = ["swiftide-integrations/redis"]
## Tree-sitter for various code transformers
tree-sitter = [
  "swiftide-integrations/tree-sitter",
  "swiftide-indexing/tree-sitter",
]
## OpenAI
openai = ["swiftide-integrations/openai"]
## Groq
groq = ["swiftide-integrations/groq"]
## Google Gemini
gemini = ["swiftide-integrations/gemini"]
## Dashscope prompting
dashscope = ["swiftide-integrations/dashscope"]
## OpenRouter prompting
open-router = ["swiftide-integrations/open-router"]
## Ollama prompting
ollama = ["swiftide-integrations/ollama"]
# Anthropic
anthropic = ["swiftide-integrations/anthropic"]
## FastEmbed (by qdrant) for fast, local, sparse and dense embeddings
fastembed = ["swiftide-integrations/fastembed"]
## Scraping via spider as loader and a html to markdown transformer
scraping = ["swiftide-integrations/scraping"]
## AWS Bedrock for prompting
aws-bedrock = ["swiftide-integrations/aws-bedrock"]
## LanceDB for persistence and querying
lancedb = ["swiftide-integrations/lancedb"]
## Fluvio loader
fluvio = ["swiftide-integrations/fluvio"]
## Kafka loader
kafka = ["swiftide-integrations/kafka"]
## Parquet loader
parquet = ["swiftide-integrations/parquet"]
## Redb embeddable nodecache
redb = ["swiftide-integrations/redb"]
## Duckdb; sqlite fork, support Persist, Retrieve and NodeCache
duckdb = ["swiftide-integrations/duckdb"]
#!
### Other ## MCP tool support for agents (tools only) mcp = ["swiftide-agents", "swiftide-agents/mcp"] ## Metrics for usage, pipeline and agent performance metrics = ["swiftide-integrations/metrics", "swiftide-core/metrics"] ## Various mocking and testing utilities test-utils = ["swiftide-core/test-utils"] ## Json schema for various types json-schema = ["swiftide-core/json-schema", "swiftide-agents/json-schema"] ## Estimate token counts using tiktoken tiktoken = ["swiftide-integrations/tiktoken"] #! ### Experimental ## GenAI agents and tools swiftide-agents = ["dep:swiftide-agents"] ## Langfuse tracing and observability langfuse = [ "swiftide-integrations/langfuse", "swiftide-agents/langfuse", "dep:swiftide-langfuse", ] [dev-dependencies] swiftide-core = { path = "../swiftide-core", features = ["test-utils"] } swiftide-test-utils = { path = "../swiftide-test-utils" } async-openai = { workspace = true } qdrant-client = { workspace = true, default-features = false, features = [ "serde", ] } anyhow = { workspace = true } test-log = { workspace = true } testcontainers = { workspace = true } mockall = { workspace = true } temp-dir = { workspace = true } wiremock = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true } arrow-array = { workspace = true } sqlx = { workspace = true } lancedb = { workspace = true } [lints] workspace = true [package.metadata.docs.rs] all-features = true cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: swiftide/build.rs ================================================ use std::{fs, path::Path}; fn main() { let readme_path = Path::new("README.md"); let out_dir = std::env::var("OUT_DIR").unwrap(); let out_readme = Path::new(&out_dir).join("README.docs.md"); // Read README.md let contents = fs::read_to_string(readme_path).expect("Failed to read README.md"); // Replace ```rust with 
```ignore let patched = contents.replace("```rust", "```ignore"); // Write the modified README to OUT_DIR fs::write(&out_readme, patched).expect("Failed to write patched README"); // Tell Cargo to re-run build.rs if README changes println!("cargo:rerun-if-changed=README.md"); // Export the path so we can include it in lib.rs println!("cargo:rustc-env=DOC_README={}", out_readme.display()); } ================================================ FILE: swiftide/src/lib.rs ================================================ // show feature flags in the generated documentation // https://doc.rust-lang.org/rustdoc/unstable-features.html#extensions-to-the-doc-attribute #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, doc(auto_cfg))] #![doc(html_logo_url = "https://github.com/bosun-ai/swiftide/raw/master/images/logo.png")] #![allow(unused_imports, reason = "that is what we do here")] #![allow(clippy::doc_markdown, reason = "the readme is invalid and that is ok")] #![doc = include_str!(env!("DOC_README"))] #![doc = document_features::document_features!()] #[doc(inline)] pub use swiftide_core::prompt; #[doc(inline)] pub use swiftide_core::type_aliases::*; #[cfg(feature = "swiftide-agents")] #[doc(inline)] pub use swiftide_agents as agents; /// Common traits for common behaviour, re-exported from indexing and query pub mod traits { #[doc(inline)] pub use swiftide_core::agent_traits::*; #[doc(inline)] pub use swiftide_core::chat_completion::traits::*; #[doc(inline)] pub use swiftide_core::indexing_traits::*; #[doc(inline)] pub use swiftide_core::query_traits::*; #[doc(inline)] pub use swiftide_core::token_estimation::{Estimatable, EstimateTokens}; } #[doc(inline)] pub use swiftide_core::token_estimation::CharEstimator; /// Abstractions for chat completions and LLM interactions. #[doc(inline)] pub use swiftide_core::chat_completion; /// Integrations with various platforms and external services. 
pub mod integrations { #[doc(inline)] pub use swiftide_integrations::*; } /// This module serves as the main entry point for indexing in Swiftide. /// /// The indexing system in Swiftide is designed to handle the asynchronous processing of large /// volumes of data, including loading, transforming, and storing data chunks. pub mod indexing { #[doc(inline)] pub use swiftide_core::indexing::*; #[doc(inline)] pub use swiftide_indexing::*; pub mod transformers { #[cfg(feature = "tree-sitter")] #[doc(inline)] pub use swiftide_integrations::treesitter::transformers::*; pub use swiftide_indexing::transformers::*; } /// Pipeline statistics collection for monitoring and observability pub mod statistics { #[doc(inline)] pub use swiftide_core::statistics::*; } } #[cfg(feature = "macros")] #[doc(inline)] pub use swiftide_macros::*; /// # Querying pipelines /// /// Swiftide allows you to define sophisticated query pipelines. /// /// Consider the following code that uses Swiftide to load some markdown text, chunk it, embed it, /// and store it in a Qdrant index: /// /// ```no_run /// # #[cfg(all(feature = "openai", feature = "qdrant"))] /// # { /// use swiftide::{ /// indexing::{ /// self, /// loaders::FileLoader, /// transformers::{ChunkMarkdown, Embed, MetadataQAText}, /// }, /// integrations::{self, qdrant::Qdrant}, /// integrations::openai::OpenAI, /// query::{self, answers, query_transformers, response_transformers}, /// }; /// /// async fn index() -> Result<(), Box> { /// let openai_client = OpenAI::builder() /// .default_embed_model("text-embedding-3-large") /// .default_prompt_model("gpt-4o") /// .build()?; /// /// let qdrant = Qdrant::builder() /// .batch_size(50) /// .vector_size(3072) /// .collection_name("swiftide-examples") /// .build()?; /// /// indexing::Pipeline::from_loader(FileLoader::new("README.md")) /// .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) /// .then(MetadataQAText::new(openai_client.clone())) /// 
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) /// .then_store_with(qdrant.clone()) /// .run() /// .await?; /// /// Ok(()) /// } /// # } /// ``` /// /// We could then define a query pipeline that uses the Qdrant index to answer questions: /// /// ```no_run /// # #[cfg(all(feature = "openai", feature = "qdrant"))] /// # { /// # use swiftide::{ /// # indexing::{ /// # self, /// # loaders::FileLoader, /// # transformers::{ChunkMarkdown, Embed, MetadataQAText}, /// # }, /// # integrations::{self, qdrant::Qdrant}, /// # query::{self, answers, query_transformers, response_transformers}, /// # integrations::openai::OpenAI, /// # }; /// # async fn query() -> Result<(), Box> { /// # let openai_client = OpenAI::builder() /// # .default_embed_model("text-embedding-3-large") /// # .default_prompt_model("gpt-4o") /// # .build()?; /// # let qdrant = Qdrant::builder() /// # .batch_size(50) /// # .vector_size(3072) /// # .collection_name("swiftide-examples") /// # .build()?; /// // By default the search strategy is SimilaritySingleEmbedding /// // which takes the latest query, embeds it, and does a similarity search /// let pipeline = query::Pipeline::default() /// .then_transform_query(query_transformers::GenerateSubquestions::from_client( /// openai_client.clone(), /// )) /// .then_transform_query(query_transformers::Embed::from_client( /// openai_client.clone(), /// )) /// .then_retrieve(qdrant.clone()) /// .then_transform_response(response_transformers::Summary::from_client( /// openai_client.clone(), /// )) /// .then_answer(answers::Simple::from_client(openai_client.clone())); /// /// let result = pipeline /// .query("What is swiftide? Please provide an elaborate explanation") /// .await?; /// /// println!("{:?}", result.answer()); /// # Ok(()) /// # } /// # } /// ``` /// /// By using a query pipeline to transform queries, we can improve the quality of the answers we get /// from our index. 
In this example, we used an LLM to generate subquestions, embedding those and /// then using them to search the index. Finally, we summarize the results and combine them together /// into a single answer. pub mod query { #[doc(inline)] pub use swiftide_core::querying::*; #[doc(inline)] pub use swiftide_query::*; } #[cfg(feature = "langfuse")] #[doc(inline)] pub use swiftide_langfuse as langfuse; /// Re-exports for macros #[doc(hidden)] pub mod reexports { pub use ::anyhow; pub use ::async_trait; pub use ::schemars; pub use ::serde; pub use ::serde_json; } ================================================ FILE: swiftide/src/test_utils.rs ================================================ ================================================ FILE: swiftide/tests/dyn_traits.rs ================================================ //! Tests for dyn trait objects #![cfg(all( feature = "openai", feature = "qdrant", feature = "redis", feature = "fastembed", feature = "tree-sitter" ))] use swiftide::{indexing::transformers::ChunkCode, integrations}; use swiftide_core::{ BatchableTransformer, ChunkerTransformer, EmbeddingModel, Loader, NodeCache, Persist, SimplePrompt, Transformer, }; use swiftide_indexing::{loaders, transformers}; use swiftide_integrations::fastembed::FastEmbed; #[test_log::test(tokio::test)] async fn test_name_on_dyn() { let fastembed: Box = Box::new(FastEmbed::try_default().unwrap()); assert_eq!(fastembed.name(), "FastEmbed"); let chunk_code: Box> = Box::new(ChunkCode::try_for_language("rust").unwrap()); assert_eq!(chunk_code.name(), "ChunkCode"); let transformer: Box> = Box::new(transformers::MetadataQAText::default()); assert_eq!(transformer.name(), "MetadataQAText"); let redis: Box> = Box::new( integrations::redis::Redis::try_from_url("redis://localhost:6379", "prefix").unwrap(), ); assert_eq!(redis.name(), "Redis"); let embed: Box> = Box::new(transformers::Embed::new(fastembed).with_batch_size(10)); assert_eq!(embed.name(), "Embed"); let qdrant: Box> = Box::new( 
integrations::qdrant::Qdrant::try_from_url("http://localhost:6333") .unwrap() .vector_size(1536) .build() .unwrap(), ); assert_eq!(qdrant.name(), "Qdrant"); let openai_client: Box = Box::new( integrations::openai::OpenAI::builder() .default_embed_model("text-embedding-3-small") .default_prompt_model("gpt-3.5-turbo") .build() .unwrap(), ); assert_eq!(openai_client.name(), "GenericOpenAI"); let loader: Box> = Box::new(loaders::FileLoader::new(".").with_extensions(&["rs"])); assert_eq!(loader.name(), "FileLoader"); } ================================================ FILE: swiftide/tests/indexing_pipeline.rs ================================================ //! This module contains tests for the indexing pipeline in the Swiftide project. //! The tests validate the functionality of the pipeline, ensuring it processes data correctly //! from a temporary file, simulates API responses, and stores data accurately in the Qdrant vector //! database. #![cfg(all( feature = "openai", feature = "qdrant", feature = "redis", feature = "tree-sitter" ))] use qdrant_client::qdrant::vectors_output::VectorsOptions; use qdrant_client::qdrant::{ScrollPointsBuilder, SearchPointsBuilder, Value}; use swiftide::indexing::*; use swiftide::integrations; use swiftide_test_utils::*; use temp_dir::TempDir; use wiremock::MockServer; /// Tests the indexing pipeline without any mocks. /// /// This test sets up a temporary directory and file, simulates API responses using mock servers, /// configures an `OpenAI` client, and runs the indexing pipeline. It then validates that the data /// is correctly stored in the Qdrant vector database. /// /// # Panics /// Panics if any of the setup steps fail, such as creating the temporary directory or file, /// starting the mock server, or configuring the `OpenAI` client. /// /// # Errors /// If the indexing pipeline encounters an error, the test will print the received requests /// for debugging purposes. 
#[test_log::test(tokio::test)] async fn test_indexing_pipeline() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); std::fs::write(&codefile, "fn main() { println!(\"Hello, World!\"); }").unwrap(); // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_chat_completions(&mock_server).await; mock_embeddings(&mock_server, 1).await; let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); let (_redis, redis_url) = start_redis().await; let (qdrant_container, qdrant_url) = start_qdrant().await; // Coverage CI runs in container, just accept the double qdrant and use the service instead let qdrant_url = std::env::var("QDRANT_URL").unwrap_or(qdrant_url); println!("Qdrant URL: {qdrant_url}"); let result = Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .with_default_llm_client(openai_client.clone()) .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) .then(transformers::MetadataQACode::default()) .filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap()) .then_in_batch(transformers::Embed::new(openai_client.clone()).with_batch_size(1)) .log_nodes() .then_store_with( integrations::qdrant::Qdrant::try_from_url(&qdrant_url) .unwrap() .vector_size(1536) .collection_name("swiftide-test".to_string()) .build() .unwrap(), ) .run() .await; if result.is_err() { println!("\n Received the following requests: \n"); // Just some serde magic to pretty print requests on failure let received_requests = mock_server .received_requests() .await .unwrap_or_default() .into_iter() .map(|req| { format!( "- {} {}\n{}", req.method, req.url, serde_json::to_string_pretty( &serde_json::from_slice::(&req.body).unwrap() ) .unwrap() ) }) .collect::>() .join("\n---\n"); println!("{received_requests}"); } result.expect("Indexing pipeline failed"); let qdrant_client 
= qdrant_client::Qdrant::from_url(&qdrant_url) .build() .unwrap(); let stored_node = qdrant_client .scroll( ScrollPointsBuilder::new("swiftide-test") .limit(1) .with_payload(true) .with_vectors(true), ) .await .unwrap(); dbg!( std::str::from_utf8(&qdrant_container.stdout_to_vec().await.unwrap()) .unwrap() .split('\n') .collect::>() ); dbg!( std::str::from_utf8(&qdrant_container.stderr_to_vec().await.unwrap()) .unwrap() .split('\n') .collect::>() ); dbg!(stored_node); let search_request = SearchPointsBuilder::new("swiftide-test", vec![0_f32; 1536], 10).with_payload(true); let search_response = qdrant_client.search_points(search_request).await.unwrap(); dbg!(&search_response); let first = search_response.result.first().unwrap(); dbg!(first); assert!( first .payload .get("path") .unwrap() .as_str() .unwrap() .ends_with("main.rs") ); assert_eq!( first.payload.get("content").unwrap().as_str().unwrap(), "fn main() { println!(\"Hello, World!\"); }" ); assert_eq!( first .payload .get("Questions and Answers (code)") .unwrap() .as_str() .unwrap(), "\n\nHello there, how may I assist you today?" 
); } #[test_log::test(tokio::test)] async fn test_named_vectors() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); std::fs::write(&codefile, "fn main() { println!(\"Hello, World!\"); }").unwrap(); // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_chat_completions(&mock_server).await; mock_embeddings(&mock_server, 2).await; let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); let (_redis, redis_url) = start_redis().await; let (_qdrant, qdrant_url) = start_qdrant().await; // Coverage CI runs in container, just accept the double qdrant and use the service instead let qdrant_url = std::env::var("QDRANT_URL").unwrap_or(qdrant_url); println!("Qdrant URL: {qdrant_url}"); let result = Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .with_embed_mode(EmbedMode::PerField) .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) .then(transformers::MetadataQACode::new(openai_client.clone())) .filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap()) .then_in_batch(transformers::Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with( integrations::qdrant::Qdrant::try_from_url(&qdrant_url) .unwrap() .vector_size(1536) .collection_name("named-vectors-test".to_string()) .with_vector(EmbeddedField::Chunk) .with_vector(EmbeddedField::Metadata( transformers::metadata_qa_code::NAME.into(), )) .build() .unwrap(), ) .run() .await; result.expect("Named vectors test indexing pipeline failed"); let qdrant_client = qdrant_client::Qdrant::from_url(&qdrant_url) .build() .unwrap(); let search_request = SearchPointsBuilder::new("named-vectors-test", vec![0_f32; 1536], 10) .vector_name( EmbeddedField::Metadata(transformers::metadata_qa_code::NAME.into()).to_string(), ) .with_payload(true) .with_vectors(true); let search_response 
= qdrant_client.search_points(search_request).await.unwrap(); let first = search_response.result.into_iter().next().unwrap(); assert!( first .payload .get("path") .unwrap() .as_str() .unwrap() .ends_with("main.rs") ); assert_eq!( first.payload.get("content").unwrap().as_str().unwrap(), "fn main() { println!(\"Hello, World!\"); }" ); assert_eq!( first .payload .get("Questions and Answers (code)") .unwrap() .as_str() .unwrap(), "\n\nHello there, how may I assist you today?" ); let vectors = first.vectors.expect("Response has vectors"); let VectorsOptions::Vectors(named_vectors) = vectors .vectors_options .expect("Response has vector options") else { panic!("Expected named vectors"); }; let vectors = named_vectors.vectors; assert_eq!(vectors.len(), 2); assert!(vectors.contains_key(&EmbeddedField::Chunk.to_string())); assert!(vectors.contains_key( &EmbeddedField::Metadata(transformers::metadata_qa_code::NAME.into()).to_string() )); } ================================================ FILE: swiftide/tests/lancedb.rs ================================================ #![cfg(all( feature = "openai", feature = "lancedb", feature = "fastembed", feature = "tree-sitter" ))] use anyhow::Context; use lancedb::query::{self as lance_query_builder, QueryBase}; use swiftide::indexing::{self, TextNode}; use swiftide::indexing::{ EmbeddedField, transformers::{ChunkCode, MetadataQACode, metadata_qa_code::NAME as METADATA_QA_CODE_NAME}, }; use swiftide::query::{self as swift_query_pipeline, Query, states}; use swiftide_indexing::{Pipeline, loaders, transformers}; use swiftide_integrations::{ fastembed::FastEmbed, lancedb::{self as lance_integration, LanceDB}, }; use swiftide_query::{answers, query_transformers, response_transformers}; use swiftide_test_utils::{mock_chat_completions, openai_client}; use temp_dir::TempDir; use wiremock::MockServer; #[test_log::test(tokio::test)] async fn test_lancedb() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); 
let codefile = tempdir.child("main.rs"); let code = "fn main() { println!(\"Hello, World!\"); }"; std::fs::write(&codefile, code).unwrap(); // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_chat_completions(&mock_server).await; let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); let fastembed = FastEmbed::try_default().unwrap(); let lancedb = LanceDB::builder() .uri(tempdir.child("lancedb").to_str().unwrap()) .vector_size(384) .with_vector(EmbeddedField::Combined) .with_metadata(METADATA_QA_CODE_NAME) .with_metadata("filter") .with_metadata("path") .table_name("swiftide_test") .build() .unwrap(); Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .then_chunk(ChunkCode::try_for_language("rust").unwrap()) .then(MetadataQACode::new(openai_client.clone())) .then(|mut node: TextNode| { // Add path to metadata, by default, storage will store all metadata fields node.metadata .insert("path", node.path.display().to_string()); node.metadata.insert("filter", "true"); Ok(node) }) .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) .log_nodes() .then_store_with(lancedb.clone()) .run() .await .unwrap(); let strategy = swift_query_pipeline::search_strategies::SimilaritySingleEmbedding::from_filter( "filter = \"true\"".to_string(), ); let query_pipeline = swift_query_pipeline::Pipeline::from_search_strategy(strategy) .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) .then_retrieve(lancedb.clone()) .then_transform_response(response_transformers::Summary::from_client( openai_client.clone(), )) .then_answer(answers::Simple::from_client(openai_client.clone())); let result: Query = query_pipeline.query("What is swiftide?").await.unwrap(); dbg!(&result); assert_eq!( result.answer(), "\n\nHello there, how 
may I assist you today?" ); let retrieved_document = result.documents().first().unwrap(); assert_eq!(retrieved_document.content(), code); assert_eq!( retrieved_document.metadata().get("path").unwrap(), codefile.to_str().unwrap() ); } #[test_log::test(tokio::test)] async fn test_lancedb_retrieve_dynamic_search() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); let code = "fn main() { println!(\"Hello, World!\"); }"; std::fs::write(&codefile, code).unwrap(); // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_chat_completions(&mock_server).await; let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); let fastembed = FastEmbed::try_default().unwrap(); let lancedb = LanceDB::builder() .uri(tempdir.child("lancedb").to_str().unwrap()) .vector_size(384) .with_vector(EmbeddedField::Combined) .with_metadata(METADATA_QA_CODE_NAME) .with_metadata("filter") .with_metadata("path") .table_name("swiftide_test") .build() .unwrap(); Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .then_chunk(ChunkCode::try_for_language("rust").unwrap()) .then(MetadataQACode::new(openai_client.clone())) .then(|mut node: indexing::TextNode| { // Add path to metadata, by default, storage will store all metadata fields node.metadata .insert("path", node.path.display().to_string()); node.metadata .insert("filter".to_string(), "true".to_string()); Ok(node) }) .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) .log_nodes() .then_store_with(lancedb.clone()) .run() .await .unwrap(); // Create the custom query strategy for vector similarity search let create_vector_search_strategy = |lancedb: &LanceDB, table_name: String| -> swift_query_pipeline::search_strategies::CustomStrategy< lance_query_builder::VectorQuery, > { let table_name = table_name.clone(); let lancedb = 
lancedb.clone(); swift_query_pipeline::search_strategies::CustomStrategy::from_async_query( move |query_node| { // Create owned copies for the async block let table_name = table_name.clone(); let lancedb = lancedb.clone(); let embedding = if let Some(embedding) = &query_node.embedding { embedding.clone() } else { panic!("Query embedding not found"); }; // Return a Future using async block syntax Box::pin(async move { // Create a new connection for each query execution let connection = lancedb.get_connection().await?; // Open the table within the query execution context let vector_table = connection .open_table(&table_name) .execute() .await .context("Failed to open vector search table")?; let vector_field = lance_integration::VectorConfig::from(EmbeddedField::Combined) .field_name(); // Build and return the query let query_builder = vector_table .query() .nearest_to(embedding.as_slice())? .column(&vector_field) .limit(20); Ok(query_builder) // Connection is dropped here when query_builder is executed }) }, ) }; let vector_search_strategy = create_vector_search_strategy(&lancedb, "swiftide_test".to_string()); let query_pipeline = swift_query_pipeline::Pipeline::from_search_strategy(vector_search_strategy) .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) .then_retrieve(lancedb.clone()) .then_transform_response(response_transformers::Summary::from_client( openai_client.clone(), )) .then_answer(answers::Simple::from_client(openai_client.clone())); let result: Query = query_pipeline.query("What is swiftide?").await.unwrap(); dbg!(&result); assert_eq!( result.answer(), "\n\nHello there, how may I assist you today?" 
); let retrieved_document = result.documents().first().unwrap(); assert_eq!(retrieved_document.content(), code); assert_eq!( retrieved_document.metadata().get("path").unwrap(), codefile.to_str().unwrap() ); } ================================================ FILE: swiftide/tests/pgvector.rs ================================================ //! This module contains tests for the `PgVector` indexing pipeline in the Swiftide project. //! The tests validate the functionality of the pipeline, ensuring that data is correctly indexed //! and processed from temporary files, database configurations, and simulated environments. #![cfg(all( feature = "openai", feature = "pgvector", feature = "fastembed", feature = "tree-sitter" ))] use swiftide_core::document::Document; use swiftide_integrations::treesitter::metadata_qa_code; use temp_dir::TempDir; use anyhow::{Result, anyhow}; use sqlx::{prelude::FromRow, types::Uuid}; use swiftide::{ indexing::{ self, EmbeddedField, Pipeline, loaders, transformers::{ self, ChunkCode, MetadataQACode, metadata_qa_code::NAME as METADATA_QA_CODE_NAME, }, }, integrations::{ self, pgvector::{FieldConfig, PgVector, PgVectorBuilder, VectorConfig}, }, query::{self, Query, answers, query_transformers, response_transformers, states}, }; use swiftide_test_utils::{mock_chat_completions, openai_client}; use wiremock::MockServer; #[allow(dead_code)] #[derive(Debug, Clone, FromRow)] struct VectorSearchResult { id: Uuid, chunk: String, } /// Test case for verifying the `PgVector` indexing pipeline functionality. /// /// This test: /// - Sets up a temporary file and Postgres database for testing. /// - Configures a `PgVector` instance with a vector size of 384. /// - Executes an indexing pipeline for Rust code chunks with embedded vector metadata. /// - Performs a similarity-based vector search on the database and validates the retrieved results. 
/// /// Ensures correctness of end-to-end data flow, including table management, vector storage, and /// query execution. #[test_log::test(tokio::test)] async fn test_pgvector_indexing() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); let code = "fn main() { println!(\"Hello, World!\"); }"; std::fs::write(&codefile, code).unwrap(); let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_chat_completions(&mock_server).await; // Configure Pgvector with a default vector size, a single embedding // and in addition to embedding the text metadata, also store it in a field let pgv_storage = PgVector::builder() .db_url(pgv_db_url) .vector_size(384) .with_vector(EmbeddedField::Combined) .table_name("swiftide_test") .build() .unwrap(); // Drop the existing test table before running the test println!("Dropping existing test table & index if it exists"); let drop_table_sql = "DROP TABLE IF EXISTS swiftide_test"; let drop_index_sql = "DROP INDEX IF EXISTS swiftide_test_embedding_idx"; if let Ok(pool) = pgv_storage.get_pool().await { sqlx::query(drop_table_sql) .execute(pool) .await .expect("Failed to execute SQL query for dropping the table"); sqlx::query(drop_index_sql) .execute(pool) .await .expect("Failed to execute SQL query for dropping the index"); } else { panic!("Unable to acquire database connection pool"); } let result = Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .then_chunk(ChunkCode::try_for_language("rust").unwrap()) .then(|mut node: indexing::TextNode| { node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]); Ok(node) }) .then_store_with(pgv_storage.clone()) .run() .await; result.expect("PgVector Named vectors test indexing pipeline failed"); let pool = pgv_storage .get_pool() .await .expect("Unable to 
acquire database connection pool"); // Start building the SQL query let sql_vector_query = "SELECT id, chunk FROM swiftide_test ORDER BY vector_combined <=> $1::VECTOR LIMIT $2"; println!("Running retrieve with SQL: {sql_vector_query}"); let top_k: i32 = 10; let embedding = vec![1.0; 384]; let data: Vec = sqlx::query_as(sql_vector_query) .bind(embedding) .bind(top_k) .fetch_all(pool) .await .expect("Sql named vector query failed"); let docs: Vec<_> = data.into_iter().map(|r| r.chunk).collect(); println!("Retrieved documents for debugging: {docs:#?}"); assert_eq!(docs[0], "fn main() { println!(\"Hello, World!\"); }"); } /// Test the retrieval functionality of `PgVector` integration. /// /// This test verifies that a Rust code snippet can be embedded, /// stored in a `PostgreSQL` database using `PgVector`, and accurately /// retrieved using a single similarity-based query pipeline. It sets up /// a mock `OpenAI` client, configures `PgVector`, and executes a query /// to ensure the pipeline retrieves the correct data and generates /// an expected response. 
#[test_log::test(tokio::test)] async fn test_pgvector_retrieve() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); let code = "fn main() { println!(\"Hello, World!\"); }"; std::fs::write(&codefile, code).unwrap(); let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_chat_completions(&mock_server).await; let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); let fastembed = integrations::fastembed::FastEmbed::try_default().expect("Could not create FastEmbed"); // Configure Pgvector with a default vector size, a single embedding // and in addition to embedding the text metadata, also store it in a field let pgv_storage = PgVector::builder() .db_url(pgv_db_url) .vector_size(384) .with_vector(EmbeddedField::Combined) .with_metadata(METADATA_QA_CODE_NAME) .with_metadata("filter") .table_name("swiftide_test") .build() .unwrap(); // Drop the existing test table before running the test println!("Dropping existing test table & index if it exists"); let drop_table_sql = "DROP TABLE IF EXISTS swiftide_test"; let drop_index_sql = "DROP INDEX IF EXISTS swiftide_test_embedding_idx"; if let Ok(pool) = pgv_storage.get_pool().await { sqlx::query(drop_table_sql) .execute(pool) .await .expect("Failed to execute SQL query for dropping the table"); sqlx::query(drop_index_sql) .execute(pool) .await .expect("Failed to execute SQL query for dropping the index"); } else { panic!("Unable to acquire database connection pool"); } Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .then_chunk(ChunkCode::try_for_language("rust").unwrap()) .then(MetadataQACode::new(openai_client.clone())) .then(|mut node: indexing::TextNode| { node.metadata .insert("filter".to_string(), "true".to_string()); Ok(node) }) 
.then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) .log_nodes() .then_store_with(pgv_storage.clone()) .run() .await .unwrap(); let strategy = query::search_strategies::SimilaritySingleEmbedding::from_filter( "filter = \"true\"".to_string(), ); let query_pipeline = query::Pipeline::from_search_strategy(strategy) .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) .then_retrieve(pgv_storage.clone()) .then_transform_response(response_transformers::Summary::from_client( openai_client.clone(), )) .then_answer(answers::Simple::from_client(openai_client.clone())); let result: Query = query_pipeline.query("What is swiftide?").await.unwrap(); assert_eq!( result.answer(), "\n\nHello there, how may I assist you today?" ); let first_document = result.documents().first().unwrap(); let expected = Document::builder() .content("fn main() { println!(\"Hello, World!\"); }") .metadata([ ( metadata_qa_code::NAME, "\n\nHello there, how may I assist you today?", ), ("filter", "true"), ]) .build() .unwrap(); assert_eq!(first_document, &expected); } /// Tests the dynamic vector similarity search functionality using `PostgreSQL`. /// /// This integration test verifies the complete workflow of vector similarity search: /// 1. Sets up a temporary test environment with a sample Rust code file /// 2. Configures `PostgreSQL` with pgvector extension for vector operations /// 3. Creates and populates test data using a processing pipeline: /// - Loads source code files /// - Chunks code into processable segments /// - Generates metadata using `OpenAI` /// - Embeds text using `FastEmbed` /// - Stores processed data in `PostgreSQL` /// 4. Implements a custom search strategy that: /// - Filters results based on metadata /// - Orders results by vector similarity /// - Limits the number of returned results /// 5. 
Executes a query pipeline that: /// - Generates and embeds the search query /// - Retrieves similar documents /// - Transforms results into a meaningful summary /// - Produces a final answer /// /// # Configuration Pattern /// The test demonstrates the recommended configuration approach: /// - Define search parameters as constants in the implementation scope /// - Pass configuration through the query generator closure /// - Keep the strategy struct minimal and focused on query generation #[test_log::test(tokio::test)] async fn test_pgvector_retrieve_dynamic_search() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); let code = "fn main() { println!(\"Hello, World!\"); }"; std::fs::write(&codefile, code).unwrap(); let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_chat_completions(&mock_server).await; let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); let fastembed = integrations::fastembed::FastEmbed::try_default().expect("Could not create FastEmbed"); // Configure Pgvector with a default vector size, a single embedding // and in addition to embedding the text metadata, also store it in a field let pgv_storage = PgVector::builder() .db_url(pgv_db_url) .vector_size(384) .with_vector(EmbeddedField::Combined) .with_metadata(METADATA_QA_CODE_NAME) .with_metadata("filter") .table_name("swiftide_test") .build() .unwrap(); // Drop the existing test table before running the test println!("Dropping existing test table & index if it exists"); let drop_table_sql = "DROP TABLE IF EXISTS swiftide_test"; let drop_index_sql = "DROP INDEX IF EXISTS swiftide_test_embedding_idx"; if let Ok(pool) = pgv_storage.get_pool().await { sqlx::query(drop_table_sql) .execute(pool) .await .expect("Failed to execute SQL query for dropping the 
table"); sqlx::query(drop_index_sql) .execute(pool) .await .expect("Failed to execute SQL query for dropping the index"); } else { panic!("Unable to acquire database connection pool"); } Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .then_chunk(ChunkCode::try_for_language("rust").unwrap()) .then(MetadataQACode::new(openai_client.clone())) .then(|mut node: indexing::TextNode| { node.metadata .insert("filter".to_string(), "true".to_string()); Ok(node) }) .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) .log_nodes() .then_store_with(pgv_storage.clone()) .run() .await .unwrap(); // First, we'll clone pgv_storage before using it in the closure let pgv_storage_for_closure = pgv_storage.clone(); // Configure search strategy // Create a custom query generator with metadata filtering let custom_strategy = query::search_strategies::CustomStrategy::from_query( move |query_node| -> Result> { const CUSTOM_STRATEGY_MAX_RESULTS: i64 = 5; let mut builder = sqlx::QueryBuilder::new(""); let table: &str = pgv_storage_for_closure.get_table_name(); // Get column definitions let default_fields: Vec<_> = PgVectorBuilder::default_fields(); let default_columns: Vec<&str> = default_fields.iter().map(FieldConfig::field_name).collect(); // Start building the query properly builder.push("SELECT "); builder.push(default_columns.join(", ")); builder.push(" FROM "); builder.push(table); // Add metadata filter builder.push(" WHERE meta_"); builder.push(PgVector::normalize_field_name("filter")); builder.push(" @> "); builder.push("'{\"filter\": \"true\"}'::jsonb"); // Add vector similarity ordering let vector_field = VectorConfig::from(EmbeddedField::Combined).field; builder.push(" ORDER BY "); builder.push(vector_field); builder.push(" <=> "); // Let QueryBuilder handle the parameter placeholders builder.push_bind( query_node .embedding .as_ref() .ok_or_else(|| anyhow!("Missing embedding in query state"))? 
.clone(), ); builder.push("::vector"); // Add LIMIT clause builder.push(" LIMIT "); builder.push_bind(CUSTOM_STRATEGY_MAX_RESULTS); Ok(builder) }, ); let query_pipeline = query::Pipeline::from_search_strategy(custom_strategy) .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) .then_retrieve(pgv_storage.clone()) .then_transform_response(response_transformers::Summary::from_client( openai_client.clone(), )) .then_answer(answers::Simple::from_client(openai_client.clone())); let result: Query = query_pipeline.query("What is swiftide?").await.unwrap(); assert_eq!( result.answer(), "\n\nHello there, how may I assist you today?" ); let first_document = result.documents().first().unwrap(); // The custom query explicitly skipped metadata let expected = Document::builder() .content("fn main() { println!(\"Hello, World!\"); }") .build() .unwrap(); assert_eq!(first_document, &expected); } ================================================ FILE: swiftide/tests/query_pipeline.rs ================================================ #![cfg(all( feature = "openai", feature = "qdrant", feature = "fastembed", feature = "tree-sitter" ))] use swiftide::indexing::{self, *}; use swiftide::query::search_strategies::HybridSearch; use swiftide::{integrations, query}; use swiftide_integrations::fastembed::FastEmbed; use swiftide_query::{answers, query_transformers, response_transformers}; use swiftide_test_utils::*; use temp_dir::TempDir; use wiremock::MockServer; #[test_log::test(tokio::test)] async fn test_query_pipeline() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); std::fs::write(&codefile, "fn main() { println!(\"Hello, World!\"); }").unwrap(); // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; 
mock_chat_completions(&mock_server).await; let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); let (_qdrant, qdrant_url) = start_qdrant().await; let qdrant_client = integrations::qdrant::Qdrant::try_from_url(&qdrant_url) .unwrap() .vector_size(384) .collection_name("swiftide-test".to_string()) .build() .unwrap(); let fastembed = integrations::fastembed::FastEmbed::try_default().unwrap(); println!("Qdrant URL: {qdrant_url}"); indexing::Pipeline::from_loader( loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"]), ) .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(1)) .then_store_with(qdrant_client.clone()) .run() .await .unwrap(); let query_pipeline = query::Pipeline::default() .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) .then_retrieve(qdrant_client.clone()) .then_transform_response(response_transformers::Summary::from_client( openai_client.clone(), )) .then_answer(answers::Simple::from_client(openai_client.clone())); let result = query_pipeline.query("What is swiftide?").await.unwrap(); assert!(result.embedding.is_some()); assert!(!result.answer().is_empty()); } #[test_log::test(tokio::test)] async fn test_hybrid_search_qdrant() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); std::fs::write(&codefile, "fn main() { println!(\"Hello, World!\"); }").unwrap(); // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_chat_completions(&mock_server).await; let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); let (_qdrant, qdrant_url) = start_qdrant().await; let batch_size = 10; let qdrant_client = 
integrations::qdrant::Qdrant::try_from_url(&qdrant_url) .unwrap() .vector_size(384) .batch_size(batch_size) .with_vector(EmbeddedField::Combined) .with_sparse_vector(EmbeddedField::Combined) .collection_name("swiftide-hybrid") .build() .unwrap(); let fastembed_sparse = FastEmbed::try_default_sparse().unwrap().clone(); let fastembed = FastEmbed::try_default().unwrap().clone(); println!("Qdrant URL: {qdrant_url}"); indexing::Pipeline::from_loader( loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"]), ) .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(batch_size)) .then_in_batch( transformers::SparseEmbed::new(fastembed_sparse.clone()).with_batch_size(batch_size), ) .then_store_with(qdrant_client.clone()) .run() .await .unwrap(); let collection = qdrant_client .client() .collection_info("swiftide-hybrid") .await .unwrap(); dbg!(collection); let query_pipeline = query::Pipeline::from_search_strategy(HybridSearch::default()) .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) .then_transform_query(query_transformers::SparseEmbed::from_client( fastembed_sparse.clone(), )) .then_retrieve(qdrant_client.clone()) .then_answer(answers::Simple::from_client(openai_client.clone())); let result = query_pipeline.query("What is swiftide?").await.unwrap(); assert!(result.embedding.is_some()); assert!(!result.answer().is_empty()); } ================================================ FILE: swiftide/tests/sparse_embeddings_and_hybrid_search.rs ================================================ //! This module contains tests for the indexing pipeline in the Swiftide project. //! The tests validate the functionality of the pipeline, ensuring it processes data correctly //! from a temporary file, simulates API responses, and stores data accurately in the Qdrant vector //! database. 
#![cfg(all(feature = "qdrant", feature = "fastembed", feature = "tree-sitter"))] use qdrant_client::qdrant::{ Fusion, PrefetchQueryBuilder, Query, QueryPointsBuilder, ScrollPointsBuilder, SearchPointsBuilder, VectorInput, }; use swiftide::indexing::*; use swiftide::integrations; use swiftide_integrations::fastembed::FastEmbed; use swiftide_test_utils::*; use temp_dir::TempDir; use wiremock::MockServer; /// Tests the indexing pipeline without any mocks. /// /// This test sets up a temporary directory and file, simulates API responses using mock servers, /// configures an `OpenAI` client, and runs the indexing pipeline. It then validates that the data /// is correctly stored in the Qdrant vector database. /// /// # Panics /// Panics if any of the setup steps fail, such as creating the temporary directory or file, /// starting the mock server, or configuring the `OpenAI` client. /// /// # Errors /// If the indexing pipeline encounters an error, the test will print the received requests /// for debugging purposes. 
#[test_log::test(tokio::test)] async fn test_sparse_indexing_pipeline() { // Setup temporary directory and file for testing let tempdir = TempDir::new().unwrap(); let codefile = tempdir.child("main.rs"); std::fs::write(&codefile, "fn main() { println!(\"Hello, World!\"); }").unwrap(); // Setup mock servers to simulate API responses let mock_server = MockServer::start().await; mock_embeddings(&mock_server, 1).await; let (qdrant_container, qdrant_url) = start_qdrant().await; let fastembed_sparse = FastEmbed::try_default_sparse().unwrap(); let fastembed = FastEmbed::try_default().unwrap(); let memory_storage = persist::MemoryStorage::default(); println!("Qdrant URL: {qdrant_url}"); let result = Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) .then_in_batch(transformers::SparseEmbed::new(fastembed_sparse).with_batch_size(20)) .then_in_batch(transformers::Embed::new(fastembed).with_batch_size(20)) .log_nodes() .then_store_with( integrations::qdrant::Qdrant::try_from_url(&qdrant_url) .unwrap() .vector_size(384) .with_vector(EmbeddedField::Combined) .with_sparse_vector(EmbeddedField::Combined) .collection_name("swiftide-test".to_string()) .build() .unwrap(), ) .then_store_with(memory_storage.clone()) .run() .await; let node = memory_storage .get_all_values() .await .first() .unwrap() .clone(); result.expect("Indexing pipeline failed"); let qdrant_client = qdrant_client::Qdrant::from_url(&qdrant_url) .build() .unwrap(); let stored_node = qdrant_client .scroll( ScrollPointsBuilder::new("swiftide-test") .limit(1) .with_payload(true) .with_vectors(true), ) .await .unwrap(); dbg!(stored_node); dbg!( std::str::from_utf8(&qdrant_container.stdout_to_vec().await.unwrap()) .unwrap() .split('\n') .collect::>() ); // Search using the dense vector let dense = node .vectors .unwrap() .into_values() .collect::>() .first() .cloned() .unwrap(); let search_request = 
SearchPointsBuilder::new("swiftide-test", dense.as_slice(), 10) .with_payload(true) .vector_name(EmbeddedField::Combined); let search_response = qdrant_client.search_points(search_request).await.unwrap(); let first = search_response.result.first().unwrap(); assert!( first .payload .get("path") .unwrap() .as_str() .unwrap() .ends_with("main.rs") ); assert_eq!( first.payload.get("content").unwrap().as_str().unwrap(), "fn main() { println!(\"Hello, World!\"); }" ); // Search using the sparse vector let sparse = node .sparse_vectors .unwrap() .into_values() .collect::>() .first() .cloned() .unwrap(); // Search sparse let search_request = SearchPointsBuilder::new("swiftide-test", sparse.values.as_slice(), 10) .sparse_indices(sparse.indices.clone()) .vector_name(format!("{}_sparse", EmbeddedField::Combined)) .with_payload(true); let search_response = qdrant_client.search_points(search_request).await.unwrap(); let first = search_response.result.first().unwrap(); assert!( first .payload .get("path") .unwrap() .as_str() .unwrap() .ends_with("main.rs") ); assert_eq!( first.payload.get("content").unwrap().as_str().unwrap(), "fn main() { println!(\"Hello, World!\"); }" ); // Search hybrid let search_response = qdrant_client .query( QueryPointsBuilder::new("swiftide-test") .with_payload(true) .add_prefetch( PrefetchQueryBuilder::default() .query(Query::new_nearest(VectorInput::new_sparse( sparse.indices, sparse.values, ))) .using("Combined_sparse") .limit(20u64), ) .add_prefetch( PrefetchQueryBuilder::default() .query(Query::new_nearest(dense)) .using("Combined") .limit(20u64), ) .query(Query::new_fusion(Fusion::Rrf)), ) .await .unwrap(); let first = search_response.result.first().unwrap(); assert!( first .payload .get("path") .unwrap() .as_str() .unwrap() .ends_with("main.rs") ); assert_eq!( first.payload.get("content").unwrap().as_str().unwrap(), "fn main() { println!(\"Hello, World!\"); }" ); } ================================================ FILE: swiftide-agents/Cargo.toml 
================================================ cargo-features = ["edition2024"] [package] name = "swiftide-agents" version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true [dependencies] swiftide-core = { path = "../swiftide-core", version = "0.32" } swiftide-indexing = { path = "../swiftide-indexing", version = "0.32" } anyhow.workspace = true async-trait.workspace = true dyn-clone.workspace = true derive_builder.workspace = true indoc.workspace = true tracing.workspace = true tokio.workspace = true # pretty_assertions.workspace = true strum.workspace = true strum_macros.workspace = true serde.workspace = true serde_json.workspace = true fs-err = { workspace = true, features = ["tokio"] } thiserror.workspace = true futures-util.workspace = true tokio-stream.workspace = true tokio-util = { workspace = true, features = ["rt"] } convert_case.workspace = true schemars = { workspace = true, features = ["derive"] } # Mcp rmcp = { workspace = true, optional = true, default-features = false, features = [ "base64", "client", "macros", "server", ] } [dev-dependencies] swiftide-core = { path = "../swiftide-core", features = ["test-utils"] } mockall.workspace = true test-log.workspace = true temp-dir.workspace = true insta.workspace = true rmcp = { workspace = true, features = ["server"] } schemars = { workspace = true } [lints] workspace = true [package.metadata.docs.rs] all-features = true cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] rustdoc-args = ["--cfg", "docsrs"] [features] mcp = ["dep:rmcp"] json-schema = ["swiftide-core/json-schema"] langfuse = [] ================================================ FILE: swiftide-agents/src/agent.rs ================================================ #![allow(dead_code)] use crate::{ default_context::DefaultContext, errors::AgentError, 
hooks::{ AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn, Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn, }, invoke_hooks, state::{self, StopReason}, system_prompt::SystemPrompt, tools::{arg_preprocessor::ArgPreprocessor, control::Stop}, }; use std::{ collections::{HashMap, HashSet, VecDeque}, hash::{DefaultHasher, Hash as _, Hasher as _}, sync::Arc, }; use derive_builder::Builder; use futures_util::stream::StreamExt; use swiftide_core::{ AgentContext, ToolBox, chat_completion::{ ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput, }, prompt::Prompt, }; use tracing::{Instrument, debug}; /// Agents are the main interface for building agentic systems. /// /// Construct agents by calling the builder, setting an llm, configure hooks, tools and other /// customizations. /// /// # Important defaults /// /// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`. /// - A default `stop` tool is provided for agents to explicitly stop if needed /// - The default `SystemPrompt` instructs the agent with chain of thought and some common /// safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient. /// /// Agents are *not* cheap to clone. However, if an agent gets cloned, it will operate on the /// same context. #[derive(Builder)] pub struct Agent { /// Hooks are functions that are called at specific points in the agent's lifecycle. #[builder(default, setter(into))] pub(crate) hooks: Vec, /// The context in which the agent operates, by default this is the `DefaultContext`. #[builder( setter(custom), default = Arc::new(DefaultContext::default()) as Arc )] pub(crate) context: Arc, /// Tools the agent can use #[builder(default = Agent::default_tools(), setter(custom))] pub(crate) tools: HashSet>, /// Toolboxes are collections of tools that can be added to the agent. 
/// /// Toolboxes make their tools available to the agent at runtime. #[builder(default)] pub(crate) toolboxes: Vec>, /// The language model that the agent uses for completion. #[builder(setter(custom))] pub(crate) llm: Box, /// System prompt for the agent when it starts /// /// Some agents profit significantly from a tailored prompt. But it is not always needed. /// /// See [`SystemPrompt`] for an opiniated, customizable system prompt. /// /// Swiftide provides a default system prompt for all agents. /// /// Alternatively you can also provide a `Prompt` directly, or disable the system prompt. /// /// # Example /// /// ```no_run /// # use swiftide_agents::system_prompt::SystemPrompt; /// # use swiftide_agents::Agent; /// Agent::builder() /// .system_prompt( /// SystemPrompt::builder().role("You are an expert engineer") /// .build().unwrap()) /// .build().unwrap(); /// ``` #[builder(setter(into, strip_option), default = Some(SystemPrompt::default()))] pub(crate) system_prompt: Option, /// Initial state of the agent #[builder(private, default = state::State::default())] pub(crate) state: state::State, /// Optional limit on the amount of loops the agent can run. /// The counter is reset when the agent is stopped. #[builder(default, setter(strip_option))] pub(crate) limit: Option, /// The maximum amount of times the failed output of a tool will be send /// to an LLM before the agent stops. Defaults to 3. /// /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be /// worth while. If the limit is not reached, the agent will send the formatted error back to /// the LLM. /// /// The limit is hashed based on the tool call name and arguments, so the limit is per tool /// call. /// /// This limit is _not_ reset when the agent is stopped. #[builder(default = 3)] pub(crate) tool_retry_limit: usize, /// Enables streaming the chat completion responses for the agent. 
#[builder(default)] pub(crate) streaming: bool, /// When set to true, any tools in `Agent::default_tools` will be omitted. Only works if you /// at at least one tool of your own. #[builder(private, default)] pub(crate) clear_default_tools: bool, /// Internally tracks the amount of times a tool has been retried. The key is a hash based on /// the name and args of the tool. #[builder(private, default)] pub(crate) tool_retries_counter: HashMap, /// The name of the agent; optional #[builder(default = "unnamed_agent".into(), setter(into))] pub(crate) name: String, /// User messages waiting for any pending tool calls to complete. #[builder(private, default)] pub(crate) pending_user_messages: VecDeque, } impl Clone for Agent { fn clone(&self) -> Self { Agent { hooks: self.hooks.clone(), context: Arc::new(self.context.clone()), tools: self.tools.clone(), toolboxes: self.toolboxes.clone(), llm: self.llm.clone(), system_prompt: self.system_prompt.clone(), state: self.state.clone(), limit: self.limit, tool_retry_limit: self.tool_retry_limit, tool_retries_counter: HashMap::new(), streaming: self.streaming, name: self.name.clone(), clear_default_tools: self.clear_default_tools, pending_user_messages: VecDeque::new(), } } } impl std::fmt::Debug for Agent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Agent") .field("name", &self.name) // display hooks as a list of type: number of hooks .field( "hooks", &self .hooks .iter() .map(std::string::ToString::to_string) .collect::>(), ) .field( "tools", &self .tools .iter() .map(swiftide_core::Tool::name) .collect::>(), ) .field("llm", &"Box") .field("state", &self.state) .finish() } } impl AgentBuilder { /// The context in which the agent operates, by default this is the `DefaultContext`. 
// NOTE(review): angle-bracket generic parameters were stripped from this file by
// text extraction (e.g. `Arc<dyn AgentContext>` became `Arc`); they are restored
// below so these builder methods compile again. Behavior is otherwise unchanged.
pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
where
    Self: Clone,
{
    // Type-erase the concrete context so the agent can hold any implementation.
    self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
    self
}

/// Returns a mutable reference to the system prompt, if it is set.
pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
    // The builder stores `Option<Option<SystemPrompt>>`: the outer `Option` tracks
    // whether the field was set at all, the inner one whether a prompt is present
    // (`Some(None)` means "explicitly disabled", see `no_system_prompt`).
    self.system_prompt.as_mut().and_then(Option::as_mut)
}

/// Disable the system prompt.
pub fn no_system_prompt(&mut self) -> &mut Self {
    self.system_prompt = Some(None);
    self
}

/// Add a hook to the agent.
pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
    let hooks = self.hooks.get_or_insert_with(Vec::new);
    hooks.push(hook);
    self
}

/// Adds a tool to the agent
pub fn add_tool(&mut self, tool: impl Tool + 'static) -> &mut Self {
    let tools = self.tools.get_or_insert_with(HashSet::new);
    // `HashSet::replace` returns the previously stored tool when an equal one
    // (same name) was already present; log the replacement for debuggability.
    if let Some(tool) = tools.replace(tool.boxed()) {
        tracing::debug!("Tool {} already exists, replacing", tool.name());
    }
    self
}

/// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
/// before all will not trigger again.
pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
    self.add_hook(Hook::BeforeAll(Box::new(hook)))
}

/// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
/// and then starts again. The hook runs after any `before_all` hooks and before the
/// `before_completion` hooks.
pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
    self.add_hook(Hook::OnStart(Box::new(hook)))
}

/// Add a hook that runs when the agent receives a streaming response
///
/// The response will always include both the current accumulated message and the delta
///
/// This will set `self.streaming` to true, there is no need to set it manually for the default
/// behaviour.
pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
    // Registering a stream hook implies the agent should request streaming completions.
    self.streaming = Some(true);
    self.add_hook(Hook::OnStream(Box::new(hook)))
}

/// Add a hook that runs before each completion.
pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self { self.add_hook(Hook::BeforeCompletion(Box::new(hook))) } /// Add a hook that runs after each tool. The `Result` is provided /// as mut, so the tool output can be fully modified. /// /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime /// what tool to interact with. pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self { self.add_hook(Hook::AfterTool(Box::new(hook))) } /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`. pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self { self.add_hook(Hook::BeforeTool(Box::new(hook))) } /// Add a hook that runs after each completion, before tool invocation and/or new messages. pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self { self.add_hook(Hook::AfterCompletion(Box::new(hook))) } /// Add a hook that runs after each completion, after tool invocations, right before a new loop /// might start pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self { self.add_hook(Hook::AfterEach(Box::new(hook))) } /// Add a hook that runs when a new message is added to the context. Note that each tool adds a /// separate message. pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self { self.add_hook(Hook::OnNewMessage(Box::new(hook))) } pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self { self.add_hook(Hook::OnStop(Box::new(hook))) } /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait. pub fn llm(&mut self, llm: &LLM) -> &mut Self { let boxed: Box = Box::new(llm.clone()) as Box; self.llm = Some(boxed); self } /// Removes the default `stop` tool from the agent. This allows you to add your own or use /// other methods to stop the agent. 
/// /// Note that you can also just override the tool if the name of the tool is `stop`. pub fn without_default_stop_tool(&mut self) -> &mut Self { self.clear_default_tools = Some(true); self } fn builder_default_tools(&self) -> HashSet> { if self.clear_default_tools.is_some_and(|b| b) { HashSet::new() } else { Agent::default_tools() } } /// Define the available tools for the agent. Tools must implement the `Tool` trait. /// /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools. pub fn tools>(&mut self, tools: I) -> &mut Self where TOOL: Into>, { self.tools = Some( self.builder_default_tools() .into_iter() .chain(tools.into_iter().map(Into::into)) .collect(), ); self } /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first /// time. /// /// Agents can have many toolboxes. pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self { let toolboxes = self.toolboxes.get_or_insert_with(Vec::new); toolboxes.push(Box::new(toolbox)); self } } impl Agent { /// Build a new agent pub fn builder() -> AgentBuilder { AgentBuilder::default() .tools(Agent::default_tools()) .to_owned() } /// The name of the agent pub fn name(&self) -> &str { &self.name } /// Default tools for the agent that it always includes /// Right now this is the `stop` tool, which allows the agent to stop itself. pub fn default_tools() -> HashSet> { HashSet::from([Stop::default().boxed()]) } /// Run the agent with a user message. The agent will loop completions, make tool calls, until /// no new messages are available. /// /// # Errors /// /// Errors if anything goes wrong, see `AgentError` for more details. 
#[tracing::instrument(skip_all, name = "agent.query", err)] pub async fn query(&mut self, query: impl Into) -> Result<(), AgentError> { let query = query .into() .render() .map_err(AgentError::FailedToRenderPrompt)?; self.run_agent(Some(query), false).await } /// Adds a tool to an agent at run time pub fn add_tool(&mut self, tool: Box) { if let Some(tool) = self.tools.replace(tool) { tracing::debug!("Tool {} already exists, replacing", tool.name()); } } /// Modify the tools of the agent at runtime /// /// Note that any mcp tools are added to the agent after the first start, and will only then /// also be available here. pub fn tools_mut(&mut self) -> &mut HashSet> { &mut self.tools } /// Run the agent with a user message once. /// /// # Errors /// /// Errors if anything goes wrong, see `AgentError` for more details. #[tracing::instrument(skip_all, name = "agent.query_once", err)] pub async fn query_once(&mut self, query: impl Into) -> Result<(), AgentError> { self.run_agent(Some(query), true).await } /// Run the agent with without user message. The agent will loop completions, make tool calls, /// until no new messages are available. /// /// # Errors /// /// Errors if anything goes wrong, see `AgentError` for more details. #[tracing::instrument(skip_all, name = "agent.run", err)] pub async fn run(&mut self) -> Result<(), AgentError> { self.run_agent(None::, false).await } /// Run the agent with without user message. The agent will loop completions, make tool calls, /// until /// /// # Errors /// /// Errors if anything goes wrong, see `AgentError` for more details. #[tracing::instrument(skip_all, name = "agent.run_once", err)] pub async fn run_once(&mut self) -> Result<(), AgentError> { self.run_agent(None::, true).await } /// Retrieve the message history of the agent /// /// # Errors /// /// Error if the message history cannot be retrieved, e.g. 
if the context is not set up or a /// connection fails pub async fn history(&self) -> Result, AgentError> { self.context .history() .await .map_err(AgentError::MessageHistoryError) } pub(crate) async fn run_agent( &mut self, maybe_query: Option>, just_once: bool, ) -> Result<(), AgentError> { let maybe_query = maybe_query .map(|q| q.into().render()) .transpose() .map_err(AgentError::FailedToRenderPrompt)?; if self.state.is_running() { return Err(AgentError::AlreadyRunning); } if self.state.is_pending() { if let Some(system_prompt) = &self.system_prompt { self.context .add_messages(vec![ChatMessage::System( system_prompt .to_prompt() .render() .map_err(AgentError::FailedToRenderSystemPrompt)?, )]) .await .map_err(AgentError::MessageHistoryError)?; } invoke_hooks!(BeforeAll, self); self.load_toolboxes().await?; } if let Some(query) = maybe_query { if cfg!(feature = "langfuse") { debug!(langfuse.input = query); } tracing::debug!("Queueing user message until tool outputs are recorded"); self.pending_user_messages.push_back(query); } self.invoke_pending_tool_calls().await?; if self.has_unfulfilled_tool_calls().await? { tracing::warn!( "Unfulfilled tool calls remain after invocation; agent/tool configuration is invalid" ); return Err(AgentError::UnfulfilledToolCalls); } self.flush_pending_user_messages().await?; invoke_hooks!(OnStart, self); self.state = state::State::Running; let mut loop_counter = 0; while let Some(messages) = self .context .next_completion() .await .map_err(AgentError::MessageHistoryError)? 
{ if let Some(limit) = self.limit && loop_counter >= limit { tracing::warn!("Agent loop limit reached"); break; } // If the last message contains tool calls that have not been completed, // run the tools first if let Some(ChatMessage::Assistant(_, tool_calls)) = maybe_tool_call_without_output(&messages) && tool_calls .as_ref() .is_some_and(|tool_calls| !tool_calls.is_empty()) { tracing::debug!("Uncompleted tool calls found; invoking tools"); if let Some(tool_calls) = tool_calls.as_ref() { self.invoke_tools(tool_calls).await?; } // Move on to the next tick, so that the continue; } let result = self.step(&messages, loop_counter).await; if let Err(err) = result { self.stop_with_error(&err).await; tracing::error!(error = ?err, "Agent stopped with error {err}"); return Err(err); } if just_once || self.state.is_stopped() { break; } loop_counter += 1; } // If there are no new messages, ensure we update our state self.stop(StopReason::NoNewMessages).await; Ok(()) } #[tracing::instrument(skip(self, messages), err, fields(otel.name))] async fn step( &mut self, messages: &[ChatMessage], step_count: usize, ) -> Result<(), AgentError> { tracing::Span::current().record("otel.name", format!("step-{step_count}")); debug!( tools = ?self .tools .iter() .map(|t| t.name()) .collect::>() , "Running completion for agent with {} new messages", messages.len() ); let mut chat_completion_request = ChatCompletionRequest::builder() .messages(messages) .tool_specs(self.tools.iter().map(swiftide_core::Tool::tool_spec)) .build() .map_err(AgentError::FailedToBuildRequest)?; invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request); debug!( "Calling LLM with the following new messages:\n {}", self.context .current_new_messages() .await .map_err(AgentError::MessageHistoryError)? 
.iter() .map(ToString::to_string) .collect::>() .join(",\n") ); let mut response = if self.streaming { let mut last_response = None; let mut stream = self.llm.complete_stream(&chat_completion_request).await; while let Some(response) = stream.next().await { let response = response.map_err(AgentError::CompletionsFailed)?; invoke_hooks!(OnStream, self, &response); last_response = Some(response); } tracing::trace!(?last_response, "Streaming completed"); last_response.ok_or(AgentError::EmptyStream) } else { self.llm .complete(&chat_completion_request) .await .map_err(AgentError::CompletionsFailed) }?; // The arg preprocessor helps avoid common llm errors. // This must happen as early as possible response .tool_calls .as_deref_mut() .map(ArgPreprocessor::preprocess_tool_calls); invoke_hooks!(AfterCompletion, self, &mut response); let assistant_content = response.message.take(); let assistant_tool_calls = response.tool_calls.clone(); let has_assistant_message = assistant_content.is_some() || assistant_tool_calls .as_ref() .is_some_and(|calls| !calls.is_empty()); if let Some(reasoning_items) = response.reasoning.take() { if has_assistant_message { for item in reasoning_items { self.add_message(ChatMessage::Reasoning(item)).await?; } } else { tracing::debug!( "Skipping reasoning items because no assistant message or tool call was produced" ); } } if has_assistant_message { self.add_message(ChatMessage::Assistant( assistant_content, assistant_tool_calls, )) .await?; } if let Some(tool_calls) = response.tool_calls { self.invoke_tools(&tool_calls).await?; } invoke_hooks!(AfterEach, self); Ok(()) } async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> { tracing::debug!("LLM returned tool calls: {:?}", tool_calls); let mut handles = vec![]; for tool_call in tool_calls { let Some(tool) = self.find_tool_by_name(tool_call.name()) else { tracing::warn!("Tool {} not found", tool_call.name()); continue; }; tracing::info!("Calling tool `{}`", 
tool_call.name()); // let tool_args = tool_call.args().map(String::from); let context: Arc = Arc::clone(&self.context); invoke_hooks!(BeforeTool, self, &tool_call); let tool_span = tracing::info_span!( "tool", "otel.name" = format!("tool.{}", tool.name().as_ref()), ); let handle_tool_call = tool_call.clone(); let handle = tokio::spawn(async move { let handle_tool_call = handle_tool_call; let output = tool.invoke(&*context, &handle_tool_call) .await?; if cfg!(feature = "langfuse") { tracing::debug!( langfuse.output = %output, langfuse.input = handle_tool_call.args(), tool_name = tool.name().as_ref(), ); } else { tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call"); } Ok(output) }.instrument(tool_span.or_current())); handles.push((handle, tool_call)); } for (handle, tool_call) in handles { let mut output = handle .await .map_err(|err| AgentError::ToolFailedToJoin(tool_call.name().to_string(), err))?; invoke_hooks!(AfterTool, self, &tool_call, &mut output); if let Err(error) = output { let stop = self.tool_calls_over_limit(tool_call); if stop { tracing::error!( ?error, "Tool call failed, retry limit reached, stopping agent: {error}", ); } else { tracing::warn!( ?error, tool_call = ?tool_call, "Tool call failed, retrying", ); } self.add_message(ChatMessage::ToolOutput( tool_call.clone(), ToolOutput::fail(error.to_string()), )) .await?; if stop { self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned())) .await; return Err(error.into()); } continue; } let output = output?; self.handle_control_tools(tool_call, &output).await; // Feedback required leaves the tool call open // // It assumes a follow up invocation of the agent will have the feedback approved if !output.is_feedback_required() { self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output)) .await?; } } Ok(()) } fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> { self.hooks .iter() .filter(|h| hook_type 
== (*h).into()) .collect() } fn find_tool_by_name(&self, tool_name: &str) -> Option> { self.tools .iter() .find(|tool| tool.name() == tool_name) .cloned() } // Handle any tool specific output (e.g. stop) async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) { match output { ToolOutput::Stop(maybe_message) => { tracing::warn!("Stop tool called, stopping agent"); self.stop(StopReason::RequestedByTool( tool_call.clone(), maybe_message.clone(), )) .await; } ToolOutput::FeedbackRequired(maybe_payload) => { tracing::warn!("Feedback required, stopping agent"); self.stop(StopReason::FeedbackRequired { tool_call: tool_call.clone(), payload: maybe_payload.clone(), }) .await; } ToolOutput::AgentFailed(output) => { tracing::warn!("Agent failed, stopping agent"); self.stop(StopReason::AgentFailed(output.clone())).await; } _ => (), } } /// Retrieve the system prompt, if it is set. pub fn system_prompt(&self) -> Option<&SystemPrompt> { self.system_prompt.as_ref() } /// Retrieve a mutable reference to the system prompt, if it is set. /// /// Note that the system prompt is rendered only once, when the agent starts for the first time pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> { self.system_prompt.as_mut() } fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool { let mut s = DefaultHasher::new(); tool_call.hash(&mut s); let hash = s.finish(); if let Some(retries) = self.tool_retries_counter.get_mut(&hash) { let val = *retries >= self.tool_retry_limit; *retries += 1; val } else { self.tool_retries_counter.insert(hash, 1); false } } /// Add a message to the agent's context /// /// This will trigger a `OnNewMessage` hook if its present. /// /// If you want to add a message without triggering the hook, use the context directly. /// /// # Errors /// /// Errors if the message cannot be added to the context. With the default in memory context /// that is not supposed to happen. 
#[tracing::instrument(skip_all, fields(message = message.to_string()))] pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> { invoke_hooks!(OnNewMessage, self, &mut message); self.context .add_message(message) .await .map_err(AgentError::MessageHistoryError)?; Ok(()) } /// Tell the agent to stop. It will finish it's current loop and then stop. pub async fn stop(&mut self, reason: impl Into) { if self.state.is_stopped() { return; } let reason = reason.into(); invoke_hooks!(OnStop, self, reason.clone(), None); if cfg!(feature = "langfuse") { debug!(langfuse.output = serde_json::to_string_pretty(&reason).ok()); } self.state = state::State::Stopped(reason); } pub async fn stop_with_error(&mut self, error: &AgentError) { if self.state.is_stopped() { return; } invoke_hooks!(OnStop, self, StopReason::Error, Some(error)); self.state = state::State::Stopped(StopReason::Error); } /// Access the agent's context pub fn context(&self) -> &dyn AgentContext { &self.context } /// The agent is still running pub fn is_running(&self) -> bool { self.state.is_running() } /// The agent stopped pub fn is_stopped(&self) -> bool { self.state.is_stopped() } /// The agent has not (ever) started pub fn is_pending(&self) -> bool { self.state.is_pending() } /// Get a list of tools available to the agent pub fn tools(&self) -> &HashSet> { &self.tools } pub fn state(&self) -> &state::State { &self.state } pub fn stop_reason(&self) -> Option<&StopReason> { self.state.stop_reason() } async fn has_unfulfilled_tool_calls(&self) -> Result { let history = self .context .history() .await .map_err(AgentError::MessageHistoryError)?; Ok(maybe_tool_call_without_output(&history).is_some()) } async fn invoke_pending_tool_calls(&mut self) -> Result<(), AgentError> { let history = self .context .history() .await .map_err(AgentError::MessageHistoryError)?; if let Some(ChatMessage::Assistant(_, tool_calls)) = maybe_tool_call_without_output(&history) && tool_calls .as_ref() 
.is_some_and(|tool_calls| !tool_calls.is_empty()) && let Some(tool_calls) = tool_calls.as_ref() { self.invoke_tools(tool_calls).await?; } Ok(()) } async fn flush_pending_user_messages(&mut self) -> Result<(), AgentError> { if self.pending_user_messages.is_empty() { return Ok(()); } let messages = self .pending_user_messages .drain(..) .map(ChatMessage::new_user) .collect(); self.context .add_messages(messages) .await .map_err(AgentError::MessageHistoryError)?; Ok(()) } async fn load_toolboxes(&mut self) -> Result<(), AgentError> { for toolbox in &self.toolboxes { let tools = toolbox .available_tools() .await .map_err(AgentError::ToolBoxFailedToLoad)?; self.tools.extend(tools); } Ok(()) } } /// Reverse searches through messages, if it encounters a tool call before encountering an output, /// it will return the chat message with the tool calls, otherwise it returns None fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> { for message in messages.iter().rev() { if let ChatMessage::ToolOutput(..) 
            = message
        {
            return None;
        }
        if let ChatMessage::Assistant(_, tool_calls) = message
            && tool_calls
                .as_ref()
                .is_some_and(|tool_calls| !tool_calls.is_empty())
        {
            return Some(message);
        }
    }

    None
}

#[cfg(test)]
mod tests {
    use serde::ser::Error;
    use swiftide_core::ToolFeedback;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
        user,
    };
    use crate::test_utils::{MockHook, MockTool};

    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        // Create a prompt
        let mock_llm = MockChatCompletion::new();

        // Build the agent
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        // Check that the context is the default context

        // Check that the default tools are added
        assert!(agent.find_tool_by_name("stop").is_some());

        // Check it does not allow duplicates
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // It should include the default tool if a different tool is provided
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());
        assert!(agent.context().history().await.unwrap().is_empty());
    }

    // One full loop: the LLM requests a tool call, the tool output is fed back,
    // and the second completion stops the agent.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");
            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    // `query_once` should only run a single completion, even if tool calls come back.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");
            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    // Same as above, but the tool is provided indirectly via a toolbox.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_via_toolbox_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");
            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response!
        {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .add_toolbox(vec![mock_tool.boxed()])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    // Multiple tool calls in one completion should all be invoked before the next completion.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");
            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool1", "mock_tool2"]
        };

        dbg!(&chat_request);
        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");
            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    // Pending -> running -> stopped transitions.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };

        let mock_tool_response = chat_response!
        {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    // Every hook should be invoked the expected number of times for a single
    // tool-call-then-stop run.
    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();

        // Once for mock tool and once for stop
        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");
            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response!
        {
            "Final response";
            tool_calls = ["stop"]
        };
        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1) // Setting the loop limit to 1
            .build()
            .unwrap();

        // Run the agent
        agent.query(prompt).await.unwrap();

        // Assert that the remaining message is still in the queue
        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        // Assert that the agent is stopped after reaching the loop limit
        assert!(agent.is_stopped());
    }

    // A tool that keeps failing should be retried up to `tool_retry_limit`, then
    // surface the error and stop the agent.
    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
        // error
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };

        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
            tools = [mock_tool.clone()]
        };

        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
            .build()
            .unwrap();

        // Run the agent
        let result = agent.query(prompt).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }

    // The on_stream hook should fire once per streamed chunk.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_streaming() {
        let prompt = "Generate content"; // Example prompt
        let mock_llm = MockChatCompletion::new();
        let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();

        let chat_request = chat_request! {
            user!(prompt);
            tools = []
        };

        let response = chat_response! {
            "one two three";
            tool_calls = ["stop"]
        };

        // Set expectations for the mock LLM responses
        mock_llm.expect_complete(chat_request, Ok(response));

        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .on_stream(on_stream_fn.on_stream_fn())
            .no_system_prompt()
            .build()
            .unwrap();

        // Run the agent
        agent.query(prompt).await.unwrap();
        tracing::debug!("Agent finished running");

        // Assert that the agent is stopped after reaching the loop limit
        assert!(agent.is_stopped());
    }

    // An agent can be rebuilt from a (serialized) message history and continue
    // the conversation where it left off.
    #[test_log::test(tokio::test)]
    async fn test_recovering_agent_existing_history() {
        // First, let's run an agent
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");
            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response!
{ "Roses are red"; tool_calls = ["stop"] }; mock_llm.expect_complete(chat_request, Ok(stop_response)); mock_tool.expect_invoke_ok("Great!".into(), None); let mut agent = Agent::builder() .tools([mock_tool.clone()]) .llm(&mock_llm) .no_system_prompt() .build() .unwrap(); agent.query(prompt).await.unwrap(); // Let's retrieve the history of the agent let history = agent.history().await.unwrap(); // Store it as a string somewhere let serialized = serde_json::to_string(&history).unwrap(); // Retrieve it let history: Vec = serde_json::from_str::>(&serialized) .unwrap() .into_iter() .map(|message| message.to_owned()) .collect(); // Build a context from the history let context = DefaultContext::default() .with_existing_messages(history) .await .unwrap() .to_owned(); let stop_output = ToolOutput::stop(); let expected_chat_request = chat_request! { user!("Write a poem"), assistant!("Roses are red", ["mock_tool"]), tool_output!("mock_tool", "Great!"), assistant!("Roses are red", ["stop"]), tool_output!("stop", stop_output), user!("Try again!"); tools = [mock_tool.clone()] }; let stop_response = chat_response! { "Really stopping now"; tool_calls = ["stop"] }; mock_llm.expect_complete(expected_chat_request, Ok(stop_response)); let mut agent = Agent::builder() .context(context) .tools([mock_tool]) .llm(&mock_llm) .no_system_prompt() .build() .unwrap(); agent.query_once("Try again!").await.unwrap(); } #[test_log::test(tokio::test)] async fn test_agent_with_approval_required_tool() { use super::*; use crate::tools::control::ApprovalRequired; use crate::{assistant, chat_request, chat_response, user}; use swiftide_core::chat_completion::ToolCall; // Step 1: Build a tool that needs approval. let mock_tool = MockTool::default(); mock_tool.expect_invoke_ok("Great!".into(), None); let approval_tool = ApprovalRequired(mock_tool.boxed()); // Step 2: Set up the mock LLM. let mock_llm = MockChatCompletion::new(); let chat_req1 = chat_request! 
{ user!("Request with approval"); tools = [approval_tool.clone()] }; let chat_resp1 = chat_response! { "Completion message"; tool_calls = ["mock_tool"] }; mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1)); // The response will include the previous request, but no tool output // from the required tool let chat_req2 = chat_request! { user!("Request with approval"), assistant!("Completion message", ["mock_tool"]), tool_output!("mock_tool", "Great!"); // Simulate feedback required output tools = [approval_tool.clone()] }; let chat_resp2 = chat_response! { "Post-feedback message"; tool_calls = ["stop"] }; mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2)); // Step 3: Wire up the agent. let mut agent = Agent::builder() .tools([approval_tool]) .llm(&mock_llm) .no_system_prompt() .build() .unwrap(); // Step 4: Run agent to trigger approval. agent.query_once("Request with approval").await.unwrap(); assert!(matches!( agent.state, crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. }) )); let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone() else { panic!("Expected feedback required"); }; // Step 5: Simulate feedback, run again and assert finish. agent .context .feedback_received(&tool_call, &ToolFeedback::approved()) .await .unwrap(); tracing::debug!("running after approval"); agent.run_once().await.unwrap(); assert!(agent.is_stopped()); } #[test_log::test(tokio::test)] async fn test_agent_with_approval_required_tool_denied() { use super::*; use crate::tools::control::ApprovalRequired; use crate::{assistant, chat_request, chat_response, user}; use swiftide_core::chat_completion::ToolCall; // Step 1: Build a tool that needs approval. let mock_tool = MockTool::default(); let approval_tool = ApprovalRequired(mock_tool.boxed()); // Step 2: Set up the mock LLM. let mock_llm = MockChatCompletion::new(); let chat_req1 = chat_request! 
        {
            user!("Request with approval");
            tools = [approval_tool.clone()]
        };
        let chat_resp1 = chat_response! {
            "Completion message";
            tool_calls = ["mock_tool"]
        };
        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));

        // The response will include the previous request, but no tool output
        // from the required tool
        let chat_req2 = chat_request! {
            user!("Request with approval"),
            assistant!("Completion message", ["mock_tool"]),
            tool_output!("mock_tool", "This tool call was refused"); // Simulate feedback required output
            tools = [approval_tool.clone()]
        };
        let chat_resp2 = chat_response! {
            "Post-feedback message";
            tool_calls = ["stop"]
        };
        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));

        // Step 3: Wire up the agent.
        let mut agent = Agent::builder()
            .tools([approval_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Step 4: Run agent to trigger approval.
        agent.query_once("Request with approval").await.unwrap();

        assert!(matches!(
            agent.state,
            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
        ));

        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
        else {
            panic!("Expected feedback required");
        };

        // Step 5: Simulate feedback, run again and assert finish.
        agent
            .context
            .feedback_received(&tool_call, &ToolFeedback::refused())
            .await
            .unwrap();

        tracing::debug!("running after approval");
        agent.run_once().await.unwrap();

        let history = agent.context().history().await.unwrap();
        history
            .iter()
            .rfind(|m| {
                let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
                    return false;
                };
                msg.contains("refused")
            })
            .expect("Could not find refusal message");

        assert!(agent.is_stopped());
    }

    // A user message queued while tool calls are unfulfilled is only sent to the
    // LLM after those tool calls have produced output.
    #[test_log::test(tokio::test)]
    async fn test_defers_user_message_until_pending_tool_calls_complete() {
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();
        mock_tool.expect_invoke_ok("Tool done".into(), None);

        let context = DefaultContext::default()
            .with_existing_messages(vec![user!("Hello"), assistant!("Need tool", ["mock_tool"])])
            .await
            .unwrap()
            .to_owned();

        let expected_request = chat_request! {
            user!("Hello"),
            assistant!("Need tool", ["mock_tool"]),
            tool_output!("mock_tool", "Tool done"),
            user!("Next");
            tools = [mock_tool.clone()]
        };

        let response = chat_response!
        {
            "All set";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(expected_request, Ok(response));

        let mut agent = Agent::builder()
            .context(context)
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query_once("Next").await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_removing_default_stop_tool() {
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        // Build agent with without_default_stop_tool
        let agent = Agent::builder()
            .without_default_stop_tool()
            .tools([mock_tool.clone()])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Check that "stop" tool is NOT included
        assert!(agent.find_tool_by_name("stop").is_none());

        // Check that our provided tool is still present
        assert!(agent.find_tool_by_name("mock_tool").is_some());
    }
}


================================================
FILE: swiftide-agents/src/default_context.rs
================================================
//! Manages agent history and provides an interface for the external world
//!
//! This is the default for agents. It is fully async and shareable between agents.
//!
//! By default uses the `LocalExecutor` for tool execution.
//!
//! If chat messages include a `ChatMessage::Summary`, all previous messages are ignored except the
//! system prompt. This is useful for maintaining focus in long conversations or managing token
//! limits.
use std::{
    collections::HashMap,
    sync::{
        Arc, Mutex,
        atomic::{AtomicUsize, Ordering},
    },
};

use anyhow::Result;
use async_trait::async_trait;
use swiftide_core::{
    AgentContext, Command, CommandError, CommandOutput, MessageHistory, ToolExecutor,
};
use swiftide_core::{
    ToolFeedback,
    chat_completion::{ChatMessage, ToolCall},
};

use crate::tools::local_executor::LocalExecutor;

// TODO: Remove unit as executor and implement a local executor instead
#[derive(Clone)]
pub struct DefaultContext {
    /// Responsible for managing the conversation history
    ///
    /// By default, this is a `Arc<Mutex<Vec<ChatMessage>>>`.
message_history: Arc, /// Index in the conversation history where the next completion will start completions_ptr: Arc, /// Index in the conversation history where the current completion started /// Allows for retrieving only new messages since the last completion current_completions_ptr: Arc, /// The executor used to run tools. I.e. local, remote, docker tool_executor: Arc, /// Stop if last message is from the assistant stop_on_assistant: bool, feedback_received: Arc>>, } impl Default for DefaultContext { fn default() -> Self { DefaultContext { message_history: Arc::new(Mutex::new(Vec::new())), completions_ptr: Arc::new(AtomicUsize::new(0)), current_completions_ptr: Arc::new(AtomicUsize::new(0)), tool_executor: Arc::new(LocalExecutor::default()) as Arc, stop_on_assistant: true, feedback_received: Arc::new(Mutex::new(HashMap::new())), } } } impl std::fmt::Debug for DefaultContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("DefaultContext") .field("completion_history", &self.message_history) .field("completions_ptr", &self.completions_ptr) .field("current_completions_ptr", &self.current_completions_ptr) .field("tool_executor", &"Arc") .field("stop_on_assistant", &self.stop_on_assistant) .finish() } } impl DefaultContext { /// Create a new context with a custom executor pub fn from_executor>>(executor: T) -> DefaultContext { DefaultContext { tool_executor: executor.into(), ..Default::default() } } /// If set to true, the agent will stop if the last message is from the assistant (i.e. 
no new /// tool calls, summaries or user messages) pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self { self.stop_on_assistant = stop; self } pub fn with_message_history(&mut self, backend: impl MessageHistory + 'static) -> &mut Self { self.message_history = Arc::new(backend) as Arc; self } /// Build a context from an existing message history /// /// # Errors /// /// Errors if the message history cannot be extended /// /// # Panics /// /// Panics if the inner mutex is poisoned pub async fn with_existing_messages>( &mut self, message_history: I, ) -> Result<&mut Self> { self.message_history .overwrite(message_history.into_iter().collect()) .await?; Ok(self) } /// Add existing tool feedback to the context /// /// # Panics /// /// Panics if the inner mutex is poisoned pub fn with_tool_feedback(&mut self, feedback: impl Into>) { self.feedback_received .lock() .unwrap() .extend(feedback.into()); } } #[async_trait] impl AgentContext for DefaultContext { /// Retrieve messages for the next completion async fn next_completion(&self) -> Result>> { let history = self.message_history.history().await?; let mut current = self.completions_ptr.load(Ordering::SeqCst); // handle out of bounds; if current > length, reset current to 0 // if length is 0, return None if history.is_empty() { tracing::debug!("No messages in history for completion"); return Ok(None); } if current > history.len() { tracing::warn!( current, len = history.len(), "Completions index was higher than history length, resetting to 0; this might be a bug" ); self.completions_ptr.store(0, Ordering::SeqCst); self.current_completions_ptr.store(0, Ordering::SeqCst); current = 0; } if history[current..].is_empty() || (self.stop_on_assistant && matches!(history.last(), Some(ChatMessage::Assistant(..))) && self.feedback_received.lock().unwrap().is_empty()) { tracing::debug!(?history, "No new messages for completion"); Ok(None) } else { let previous = self.completions_ptr.swap(history.len(), 
Ordering::SeqCst); self.current_completions_ptr .store(previous, Ordering::SeqCst); Ok(Some(filter_messages_since_summary(history))) } } /// Returns the messages the agent is currently completing on async fn current_new_messages(&self) -> Result> { let current = self.current_completions_ptr.load(Ordering::SeqCst); let end = self.completions_ptr.load(Ordering::SeqCst); let history = self.message_history.history().await?; Ok(filter_messages_since_summary( history[current..end].to_vec(), )) } /// Retrieve all messages in the conversation history async fn history(&self) -> Result> { self.message_history.history().await } /// Add multiple messages to the conversation history async fn add_messages(&self, messages: Vec) -> Result<()> { self.message_history.extend_owned(messages).await } /// Add a single message to the conversation history async fn add_message(&self, item: ChatMessage) -> Result<()> { self.message_history.push_owned(item).await } /// Execute a command in the tool executor async fn exec_cmd(&self, cmd: &Command) -> Result { self.tool_executor.exec_cmd(cmd).await } fn executor(&self) -> &Arc { &self.tool_executor } /// Pops the last messages up until the previous completion /// /// LLMs failing completion for various reasons is unfortunately a common occurrence /// This gives a way to redrive the last completion in a generic way async fn redrive(&self) -> Result<()> { let mut history = self.message_history.history().await?; let previous = self.current_completions_ptr.load(Ordering::SeqCst); let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst); // delete everything after the last completion history.truncate(redrive_ptr); self.message_history.overwrite(history).await?; Ok(()) } async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option { // If feedback is present, return true with the optional payload, // and remove it // otherwise return false let mut lock = self.feedback_received.lock().unwrap(); lock.remove(tool_call) } async 
fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> { let mut lock = self.feedback_received.lock().unwrap(); // Set the message counter one back so that on a next try, the agent can resume by // trying the tool calls first. Only does this if there are no other approvals if lock.is_empty() { let previous = self.current_completions_ptr.load(Ordering::SeqCst); self.completions_ptr.swap(previous, Ordering::SeqCst); } tracing::debug!(?tool_call, context = ?self, "feedback received"); lock.insert(tool_call.clone(), feedback.clone()); Ok(()) } /// Replace the entire conversation history async fn replace_history(&self, items: Vec) -> Result<()> { self.message_history.overwrite(items).await?; self.completions_ptr.store(0, Ordering::SeqCst); self.current_completions_ptr.store(0, Ordering::SeqCst); Ok(()) } } fn filter_messages_since_summary(messages: Vec) -> Vec { let mut summary_found = false; let mut messages = messages .into_iter() .rev() .filter(|m| { if summary_found { return matches!(m, ChatMessage::System(_)); } if let ChatMessage::Summary(_) = m { summary_found = true; } true }) .collect::>(); messages.reverse(); messages } #[cfg(test)] mod tests { use crate::{assistant, tool_output, user}; use super::*; use swiftide_core::chat_completion::{ChatMessage, ToolCall}; #[tokio::test] async fn test_iteration_tracking() { let mut context = DefaultContext::default(); // Record initial chat messages context .add_messages(vec![ ChatMessage::System("You are awesome".into()), ChatMessage::User("Hello".into()), ]) .await .unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 2); assert!(context.next_completion().await.unwrap().is_none()); context .add_messages(vec![assistant!("Hey?"), user!("How are you?")]) .await .unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 4); assert!(context.next_completion().await.unwrap().is_none()); // If the last 
message is from the assistant, we should not get any more completions context .add_messages(vec![assistant!("I am fine")]) .await .unwrap(); assert!(context.next_completion().await.unwrap().is_none()); context.with_stop_on_assistant(false); assert!(context.next_completion().await.unwrap().is_some()); } #[tokio::test] async fn test_should_complete_after_tool_call() { let context = DefaultContext::default(); // Record initial chat messages context .add_messages(vec![ ChatMessage::System("You are awesome".into()), ChatMessage::User("Hello".into()), ]) .await .unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 2); assert_eq!(context.current_new_messages().await.unwrap().len(), 2); assert!(context.next_completion().await.unwrap().is_none()); context .add_messages(vec![ assistant!("Hey?", ["test"]), tool_output!("test", "Hoi"), ]) .await .unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(context.current_new_messages().await.unwrap().len(), 2); assert_eq!(messages.len(), 4); assert!(context.next_completion().await.unwrap().is_none()); } #[tokio::test] async fn test_filters_messages_before_summary() { let messages = vec![ ChatMessage::System("System message".into()), ChatMessage::User("Hello".into()), ChatMessage::new_assistant(Some("Hello there"), None), ChatMessage::Summary("Summary message".into()), ChatMessage::User("This should be ignored".into()), ]; let context = DefaultContext::default(); // Record initial chat messages context.add_messages(messages).await.unwrap(); let new_messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(new_messages.len(), 3); assert!(matches!(new_messages[0], ChatMessage::System(_))); assert!(matches!(new_messages[1], ChatMessage::Summary(_))); assert!(matches!(new_messages[2], ChatMessage::User(_))); let current_new_messages = context.current_new_messages().await.unwrap(); assert_eq!(current_new_messages.len(), 3); 
assert!(matches!(current_new_messages[0], ChatMessage::System(_))); assert!(matches!(current_new_messages[1], ChatMessage::Summary(_))); assert!(matches!(current_new_messages[2], ChatMessage::User(_))); assert!(context.next_completion().await.unwrap().is_none()); } #[tokio::test] async fn test_filters_messages_before_summary_with_assistant_last() { let messages = vec![ ChatMessage::System("System message".into()), ChatMessage::User("Hello".into()), ChatMessage::new_assistant(Some("Hello there"), None), ]; let mut context = DefaultContext::default(); context.with_stop_on_assistant(false); // Record initial chat messages context.add_messages(messages).await.unwrap(); let new_messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(new_messages.len(), 3); assert!(matches!(new_messages[0], ChatMessage::System(_))); assert!(matches!(new_messages[1], ChatMessage::User(_))); assert!(matches!(new_messages[2], ChatMessage::Assistant(..))); context .add_message(ChatMessage::Summary("Summary message 1".into())) .await .unwrap(); let new_messages = context.next_completion().await.unwrap().unwrap(); dbg!(&new_messages); assert_eq!(new_messages.len(), 2); assert!(matches!(new_messages[0], ChatMessage::System(_))); assert_eq!( new_messages[1], ChatMessage::Summary("Summary message 1".into()) ); assert!(context.next_completion().await.unwrap().is_none()); let messages = vec![ ChatMessage::User("Hello again".into()), ChatMessage::new_assistant(Some("Hello there again"), None), ]; context.add_messages(messages).await.unwrap(); let new_messages = context.next_completion().await.unwrap().unwrap(); assert!(matches!(new_messages[0], ChatMessage::System(_))); assert_eq!( new_messages[1], ChatMessage::Summary("Summary message 1".into()) ); assert_eq!(new_messages[2], ChatMessage::User("Hello again".into())); assert_eq!( new_messages[3], ChatMessage::new_assistant(Some("Hello there again".to_string()), None) ); context .add_message(ChatMessage::Summary("Summary message 
2".into())) .await .unwrap(); let new_messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(new_messages.len(), 2); assert!(matches!(new_messages[0], ChatMessage::System(_))); assert_eq!( new_messages[1], ChatMessage::Summary("Summary message 2".into()) ); } #[tokio::test] async fn test_redrive() { let context = DefaultContext::default(); // Record initial chat messages context .add_messages(vec![ ChatMessage::System("System message".into()), ChatMessage::User("Hello".into()), ]) .await .unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 2); assert!(context.next_completion().await.unwrap().is_none()); context.redrive().await.unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 2); context .add_messages(vec![ChatMessage::User("Hey?".into())]) .await .unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 3); assert!(context.next_completion().await.unwrap().is_none()); context.redrive().await.unwrap(); // Add more messages context .add_messages(vec![ChatMessage::User("How are you?".into())]) .await .unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 4); assert!(context.next_completion().await.unwrap().is_none()); // Redrive should remove the last set of messages dbg!(&context); context.redrive().await.unwrap(); dbg!(&context); // We just redrove with the same messages let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 4); assert!(context.next_completion().await.unwrap().is_none()); // Add more messages context .add_messages(vec![ ChatMessage::User("How are you really?".into()), ChatMessage::User("How are you really?".into()), ]) .await .unwrap(); // This should remove any additional messages context.redrive().await.unwrap(); // We just redrove with the same messages let messages = 
context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 4); assert!(context.next_completion().await.unwrap().is_none()); // Redrive again context.redrive().await.unwrap(); let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 4); assert!(context.next_completion().await.unwrap().is_none()); } #[tokio::test] async fn test_next_completion_empty_history() { let context = DefaultContext::default(); let next = context.next_completion().await; assert!(next.unwrap().is_none()); } #[tokio::test] async fn test_next_completion_out_of_bounds_ptr() { let context = DefaultContext::default(); context .add_messages(vec![ ChatMessage::System("System".into()), ChatMessage::User("Hi".into()), ]) .await .unwrap(); // Set completions_ptr beyond the length of messages context .completions_ptr .store(10, std::sync::atomic::Ordering::SeqCst); // Should reset the pointer and return the full messages let messages = context.next_completion().await.unwrap().unwrap(); assert_eq!(messages.len(), 2); // Second call should be empty again assert!(context.next_completion().await.unwrap().is_none()); } #[tokio::test] async fn test_replace_history_replaces_and_resets_pointers() { let mut context = DefaultContext::default(); context.with_stop_on_assistant(false); // Add some initial messages context .add_messages(vec![ ChatMessage::System("Initial".into()), ChatMessage::User("Hello".into()), ChatMessage::new_assistant(Some("Hi."), None), ]) .await .unwrap(); // Consume the messages so pointers are moved let orig = context.next_completion().await.unwrap().unwrap(); assert_eq!(orig.len(), 3); assert!(context.next_completion().await.unwrap().is_none()); // Replace history with a new set let new_msgs = vec![ ChatMessage::System("System2".into()), ChatMessage::User("User2".into()), ]; context.replace_history(new_msgs.clone()).await.unwrap(); // After replacement, next_completion should return only the new messages let replaced = 
context.next_completion().await.unwrap().unwrap(); assert_eq!(replaced, new_msgs); // Next call should yield None again assert!(context.next_completion().await.unwrap().is_none()); } } ================================================ FILE: swiftide-agents/src/errors.rs ================================================ use swiftide_core::chat_completion::{ ChatCompletionRequestBuilderError, errors::{LanguageModelError, ToolError}, }; use thiserror::Error; use tokio::task::JoinError; #[derive(Error, Debug)] pub enum AgentError { #[error("Agent is already running")] AlreadyRunning, #[error("Failed to render system prompt {0:#}")] FailedToRenderSystemPrompt(anyhow::Error), #[error("Failed to build chat completion request {0:#}")] FailedToBuildRequest(ChatCompletionRequestBuilderError), #[error("Error from LLM when running completions {0:#}")] CompletionsFailed(LanguageModelError), #[error(transparent)] ToolError(#[from] ToolError), #[error("Failed waiting for tool to finish {0:?}")] ToolFailedToJoin(String, JoinError), #[error("Failed to load tools from toolbox {0:#}")] ToolBoxFailedToLoad(anyhow::Error), #[error("Chat completion stream was empty")] EmptyStream, #[error("Failed to render prompt {0:#}")] FailedToRenderPrompt(anyhow::Error), #[error("Error with message history {0:#}")] MessageHistoryError(anyhow::Error), #[error("Unfulfilled tool calls remain after invocation")] UnfulfilledToolCalls, } ================================================ FILE: swiftide-agents/src/hooks.rs ================================================ //! Hooks are functions that are called at specific points in the agent lifecycle. //! //! //! Since rust does not have async closures, hooks have to return a boxed, pinned async block //! themselves. //! //! # Example //! //! ```no_run //! # use swiftide_core::{AgentContext, chat_completion::ChatMessage}; //! # use swiftide_agents::Agent; //! # fn test() { //! # let mut agent = swiftide_agents::Agent::builder(); //! 
agent.before_all(move |agent: &Agent| { //! Box::pin(async move { //! agent.context().add_message(ChatMessage::new_user("Hello, world")).await; //! Ok(()) //! }) //! }); //! # } //! ``` //! Rust has a long outstanding issue where it captures outer lifetimes when returning an impl //! that also has lifetimes, see [this issue](https://github.com/rust-lang/rust/issues/42940) //! //! This can happen if you write a method like `fn return_hook(&self) -> impl HookFn`, where the //! owner also has a lifetime. //! The trick is to set an explicit lifetime on self, and hook, where self must outlive the hook. //! //! # Example //! //! ```no_run //! # use swiftide_core::{AgentContext}; //! # use swiftide_agents::hooks::BeforeAllFn; //! # use swiftide_agents::Agent; //! struct SomeHook<'thing> { //! thing: &'thing str //! } //! //! impl<'thing> SomeHook<'thing> { //! fn return_hook<'tool>(&'thing self) -> impl BeforeAllFn + 'tool where 'thing: 'tool { //! move |_: &Agent| { //! Box::pin(async move {{ Ok(())}}) //! } //! } //! 
} use anyhow::Result; use std::{future::Future, pin::Pin}; use dyn_clone::DynClone; use swiftide_core::chat_completion::{ ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolOutput, errors::ToolError, }; use crate::{Agent, errors::AgentError, state::StopReason}; pub trait BeforeAllFn: for<'a> Fn(&'a Agent) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(BeforeAllFn); pub trait AfterEachFn: for<'a> Fn(&'a Agent) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(AfterEachFn); pub trait BeforeCompletionFn: for<'a> Fn( &'a Agent, &mut ChatCompletionRequest<'_>, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(BeforeCompletionFn); pub trait AfterCompletionFn: for<'a> Fn( &'a Agent, &mut ChatCompletionResponse, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(AfterCompletionFn); /// Hooks that are called after each tool pub trait AfterToolFn: for<'tool> Fn( &'tool Agent, &ToolCall, &'tool mut Result, ) -> Pin> + Send + 'tool>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(AfterToolFn); /// Hooks that are called before each tool pub trait BeforeToolFn: for<'a> Fn(&'a Agent, &ToolCall) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(BeforeToolFn); /// Hooks that are called when a new message is added to the `AgentContext` pub trait MessageHookFn: for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(MessageHookFn); /// Hooks that are called when the agent starts, either from pending or stopped pub trait OnStartFn: for<'a> Fn(&'a Agent) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(OnStartFn); /// Hooks that are called when the agent stop pub trait OnStopFn: for<'a> Fn( &'a Agent, StopReason, Option<&AgentError>, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } 
dyn_clone::clone_trait_object!(OnStopFn); pub trait OnStreamFn: for<'a> Fn( &'a Agent, &ChatCompletionResponse, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } dyn_clone::clone_trait_object!(OnStreamFn); /// Wrapper around the different types of hooks #[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)] #[strum_discriminants(name(HookTypes), derive(strum_macros::Display))] pub enum Hook { /// Runs only once for the agent when it starts BeforeAll(Box), /// Runs before every completion, yielding a mutable reference to the completion request BeforeCompletion(Box), /// Runs after every completion, yielding a mutable reference to the completion response AfterCompletion(Box), /// Runs before every tool call, yielding a reference to the tool call BeforeTool(Box), /// Runs after every tool call, yielding a reference to the tool call and a mutable result AfterTool(Box), /// Runs after all tools have completed and a single completion has been made AfterEach(Box), /// Runs when a new message is added to the `AgentContext`, yielding a mutable reference to the /// message. This is only triggered when the message is added by the agent. 
OnNewMessage(Box), /// Runs when the agent starts, either from pending or stopped OnStart(Box), /// Runs when the agent stops OnStop(Box), /// Runs when the agent streams a response OnStream(Box), } impl BeforeAllFn for F where F: for<'a> Fn(&'a Agent) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } impl AfterEachFn for F where F: for<'a> Fn(&'a Agent) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } impl BeforeCompletionFn for F where F: for<'a> Fn( &'a Agent, &mut ChatCompletionRequest<'_>, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } impl AfterCompletionFn for F where F: for<'a> Fn( &'a Agent, &mut ChatCompletionResponse, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } impl BeforeToolFn for F where F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } impl AfterToolFn for F where F: for<'tool> Fn( &'tool Agent, &ToolCall, &'tool mut Result, ) -> Pin> + Send + 'tool>> + Send + Sync + DynClone { } impl MessageHookFn for F where F: for<'a> Fn( &'a Agent, &mut ChatMessage, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } impl OnStartFn for F where F: for<'a> Fn(&'a Agent) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } impl OnStopFn for F where F: for<'a> Fn( &'a Agent, StopReason, Option<&AgentError>, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } impl OnStreamFn for F where F: for<'a> Fn( &'a Agent, &ChatCompletionResponse, ) -> Pin> + Send + 'a>> + Send + Sync + DynClone { } #[cfg(test)] mod tests { use crate::Agent; #[test] fn test_hooks_compile_sync_and_async() { Agent::builder() .before_all(|_| Box::pin(async { Ok(()) })) .on_start(|_| Box::pin(async { Ok(()) })) .before_completion(|_, _| Box::pin(async { Ok(()) })) .before_tool(|_, _| Box::pin(async { Ok(()) })) .after_tool(|_, _, _| Box::pin(async { Ok(()) })) .after_completion(|_, _| Box::pin(async { Ok(()) })); } } ================================================ FILE: swiftide-agents/src/lib.rs ================================================ 
// show feature flags in the generated documentation // https://doc.rust-lang.org/rustdoc/unstable-features.html#extensions-to-the-doc-attribute #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, doc(auto_cfg))] #![doc(html_logo_url = "https://github.com/bosun-ai/swiftide/raw/master/images/logo.png")] //! Swiftide agents are a flexible way to build fast and reliable AI agents. //! //! # Features //! //! * **Tools**: Tools can be defined as functions using the `#[tool]` attribute macro, the `Tool` //! derive macro, or manually implementing the `Tool` trait. //! * **Hooks**: At various stages of the agent lifecycle, hooks can be defined to run custom logic. //! These are defined when building the agent, and each take a closure. //! * **Context**: Agents operate in an `AgentContext`, which is a shared state between tools and //! hooks. The context is responsible for managing the completions and interacting with the //! outside world. //! * **Tool Execution**: A context takes a tool executor (local by default) to execute its tools //! on. This enables tools to be run i.e. in containers, remote, etc. //! * **System prompt defaults**: `SystemPrompt` provides a default, customizable prompt for the //! agent. If you want to provider your own prompt, the builder takes anything that converts into //! a `Prompt`, including strings. //! * **Open Telemetry**: Agents are fully instrumented with open telemetry. //! //! # Example //! //! ```ignore //! # use swiftide_agents::Agent; //! # use swiftide_integrations as integrations; //! # async fn run() -> Result<(), Box> { //! let openai = integrations::openai::OpenAI::builder() //! .default_prompt_model("gpt-4o-mini") //! .build()?; //! //! Agent::builder() //! .llm(&openai) //! .before_completion(move |_,_| //! Box::pin(async move { //! println!("Before each tool"); //! Ok(()) //! }) //! ) //! .build()? //! .query("What is the meaning of life?") //! .await?; //! # return Ok(()); //! //! # } //! ``` //! //! 
Agents run in a loop as long as they have new messages to process. mod agent; mod default_context; pub mod errors; pub mod hooks; mod state; pub mod system_prompt; pub mod tasks; pub mod tools; mod util; pub use agent::{Agent, AgentBuilder, AgentBuilderError}; pub use default_context::DefaultContext; pub use state::{State, StopReason}; #[cfg(any(test, debug_assertions))] pub mod test_utils; ================================================ FILE: swiftide-agents/src/snapshots/swiftide_agents__system_prompt__tests__customization.snap ================================================ --- source: swiftide-agents/src/system_prompt.rs expression: rendered --- # Your role special role # Guidelines you need to follow - Try to understand how to complete the task well before completing it. - special guideline # Constraints that must be adhered to - Think step by step - Think before you act; respond with your thoughts before calling tools - Do not make up any assumptions, use tools to get the information you need - Use the provided tools to interact with the system and accomplish the task - If you are stuck, or otherwise cannot complete the task, respond with your thoughts and call `stop`. - If the task is completed, or otherwise cannot continue, like requiring user feedback, call `stop`. - special constraint # Response Format - Always respond with your thoughts and reasoning for your actions in one or two sentences. Even when calling tools. - Once the goal is achieved, call the `stop` tool some additional info ================================================ FILE: swiftide-agents/src/snapshots/swiftide_agents__system_prompt__tests__to_prompt.snap ================================================ --- source: swiftide-agents/src/system_prompt.rs expression: rendered --- # Your role special role # Guidelines you need to follow - Try to understand how to complete the task well before completing it. 
- special guideline # Constraints that must be adhered to - Think step by step - Think before you act; respond with your thoughts before calling tools - Do not make up any assumptions, use tools to get the information you need - Use the provided tools to interact with the system and accomplish the task - If you are stuck, or otherwise cannot complete the task, respond with your thoughts and call `stop`. - If the task is completed, or otherwise cannot continue, like requiring user feedback, call `stop`. - special constraint # Response Format - Always respond with your thoughts and reasoning for your actions in one or two sentences. Even when calling tools. - Once the goal is achieved, call the `stop` tool some additional info ================================================ FILE: swiftide-agents/src/state.rs ================================================ //! Internal state of the agent use serde::{Deserialize, Serialize}; use serde_json::Value; use swiftide_core::chat_completion::ToolCall; #[derive(Clone, Debug, Default, strum_macros::EnumDiscriminants, strum_macros::EnumIs)] pub enum State { #[default] Pending, Running, Stopped(StopReason), } impl State { pub fn stop_reason(&self) -> Option<&StopReason> { match self { State::Stopped(reason) => Some(reason), _ => None, } } } /// The reason the agent stopped /// /// `StopReason::Other` has some convenience methods to convert from any `AsRef` #[non_exhaustive] #[derive(Clone, Debug, strum_macros::EnumIs, PartialEq, Serialize, Deserialize)] pub enum StopReason { /// A tool called stop RequestedByTool(ToolCall, Option), /// Agent failed to complete with optional message AgentFailed(Option), /// A tool repeatedly failed ToolCallsOverLimit(ToolCall), /// A tool requires feedback before it will continue FeedbackRequired { tool_call: ToolCall, payload: Option, }, /// There was an error Error, /// No new messages; stopping completions NoNewMessages, Other(String), } impl StopReason { pub fn as_requested_by_tool(&self) -> 
Option<(&ToolCall, Option<&Value>)> { if let StopReason::RequestedByTool(t, message) = self { Some((t, message.as_ref())) } else { None } } pub fn as_tool_calls_over_limit(&self) -> Option<&ToolCall> { if let StopReason::ToolCallsOverLimit(t) = self { Some(t) } else { None } } pub fn as_feedback_required(&self) -> Option<(&ToolCall, Option<&serde_json::Value>)> { if let StopReason::FeedbackRequired { tool_call, payload } = self { Some((tool_call, payload.as_ref())) } else { None } } pub fn as_error(&self) -> Option<()> { if matches!(self, StopReason::Error) { Some(()) } else { None } } pub fn as_no_new_messages(&self) -> Option<()> { if matches!(self, StopReason::NoNewMessages) { Some(()) } else { None } } pub fn as_other(&self) -> Option<&str> { if let StopReason::Other(s) = self { Some(s) } else { None } } } impl Default for StopReason { fn default() -> Self { StopReason::Other("No reason provided".into()) } } impl> From for StopReason { fn from(value: S) -> Self { StopReason::Other(value.as_ref().to_string()) } } ================================================ FILE: swiftide-agents/src/system_prompt.rs ================================================ //! The system prompt is the initial role and constraint defining message the LLM will receive for //! completion. //! //! By default, the system prompt is setup as a general-purpose chain-of-thought reasoning prompt //! with the role, guidelines, and constraints left empty for customization. //! //! You can override the the template entirely by providing your own `Prompt`. Optionally, you can //! still use the builder values by referencing them in your template. //! //! The builder provides an accessible way to build a system prompt. //! //! The agent will convert the system prompt into a prompt, adding it to the messages list the //! first time it is called. //! //! For customization, either the builder can be used to profit from defaults, or an override can //! be provided on the agent level. 
use derive_builder::Builder;
use swiftide_core::prompt::Prompt;

#[derive(Clone, Debug, Builder)]
#[builder(setter(into, strip_option))]
pub struct SystemPrompt {
    /// The role the agent is expected to fulfil.
    #[builder(default)]
    role: Option<String>,

    /// Additional guidelines for the agent to follow
    #[builder(default, setter(custom))]
    guidelines: Vec<String>,

    /// Additional constraints
    #[builder(default, setter(custom))]
    constraints: Vec<String>,

    /// Optional additional raw markdown to append to the prompt
    ///
    /// For instance, if you would like to support an AGENTS.md file, add it here.
    #[builder(default)]
    additional: Option<String>,

    /// The template to use for the system prompt
    #[builder(default = default_prompt_template())]
    template: Prompt,
}

impl SystemPrompt {
    pub fn builder() -> SystemPromptBuilder {
        SystemPromptBuilder::default()
    }

    pub fn to_prompt(&self) -> Prompt {
        self.clone().into()
    }

    /// Adds a guideline to the guidelines list.
    pub fn with_added_guideline(&mut self, guideline: impl AsRef<str>) -> &mut Self {
        self.guidelines.push(guideline.as_ref().to_string());
        self
    }

    /// Adds a constraint to the constraints list.
    pub fn with_added_constraint(&mut self, constraint: impl AsRef<str>) -> &mut Self {
        self.constraints.push(constraint.as_ref().to_string());
        self
    }

    /// Overwrites all guidelines.
    pub fn with_guidelines<T: IntoIterator<Item = S>, S: AsRef<str>>(
        &mut self,
        guidelines: T,
    ) -> &mut Self {
        self.guidelines = guidelines
            .into_iter()
            .map(|s| s.as_ref().to_string())
            .collect();
        self
    }

    /// Overwrites all constraints.
    pub fn with_constraints<T: IntoIterator<Item = S>, S: AsRef<str>>(
        &mut self,
        constraints: T,
    ) -> &mut Self {
        self.constraints = constraints
            .into_iter()
            .map(|s| s.as_ref().to_string())
            .collect();
        self
    }

    /// Changes the role.
    pub fn with_role(&mut self, role: impl Into<String>) -> &mut Self {
        self.role = Some(role.into());
        self
    }

    /// Sets the additional markdown field.
    pub fn with_additional(&mut self, additional: impl Into<String>) -> &mut Self {
        self.additional = Some(additional.into());
        self
    }

    /// Sets the template.
    pub fn with_template(&mut self, template: impl Into<Prompt>) -> &mut Self {
        self.template = template.into();
        self
    }
}

impl From<String> for SystemPrompt {
    fn from(text: String) -> Self {
        SystemPrompt {
            role: None,
            guidelines: Vec::new(),
            constraints: Vec::new(),
            additional: None,
            template: text.into(),
        }
    }
}

impl From<&'static str> for SystemPrompt {
    fn from(text: &'static str) -> Self {
        SystemPrompt {
            role: None,
            guidelines: Vec::new(),
            constraints: Vec::new(),
            additional: None,
            template: text.into(),
        }
    }
}

impl From<SystemPrompt> for SystemPromptBuilder {
    fn from(val: SystemPrompt) -> Self {
        SystemPromptBuilder {
            role: Some(val.role),
            guidelines: Some(val.guidelines),
            constraints: Some(val.constraints),
            additional: Some(val.additional),
            template: Some(val.template),
        }
    }
}

impl From<Prompt> for SystemPrompt {
    fn from(prompt: Prompt) -> Self {
        SystemPrompt {
            role: None,
            guidelines: Vec::new(),
            constraints: Vec::new(),
            additional: None,
            template: prompt,
        }
    }
}

impl Default for SystemPrompt {
    fn default() -> Self {
        SystemPrompt {
            role: None,
            guidelines: Vec::new(),
            constraints: Vec::new(),
            additional: None,
            template: default_prompt_template(),
        }
    }
}

impl SystemPromptBuilder {
    pub fn add_guideline(&mut self, guideline: &str) -> &mut Self {
        self.guidelines
            .get_or_insert_with(Vec::new)
            .push(guideline.to_string());
        self
    }

    pub fn add_constraint(&mut self, constraint: &str) -> &mut Self {
        self.constraints
            .get_or_insert_with(Vec::new)
            .push(constraint.to_string());
        self
    }

    pub fn guidelines<T: IntoIterator<Item = S>, S: AsRef<str>>(
        &mut self,
        guidelines: T,
    ) -> &mut Self {
        self.guidelines = Some(
            guidelines
                .into_iter()
                .map(|s| s.as_ref().to_string())
                .collect(),
        );
        self
    }

    pub fn constraints<T: IntoIterator<Item = S>, S: AsRef<str>>(
        &mut self,
        constraints: T,
    ) -> &mut Self {
        self.constraints = Some(
            constraints
                .into_iter()
                .map(|s| s.as_ref().to_string())
                .collect(),
        );
        self
    }
}

fn default_prompt_template() -> Prompt {
    include_str!("system_prompt_template.md").into()
}

#[allow(clippy::from_over_into)]
impl Into<Prompt> for SystemPrompt {
    fn into(self) -> Prompt {
        let SystemPrompt {
            role,
            guidelines,
            constraints,
            template,
            additional,
        } = self;

        template
            .with_context_value("role", role)
            .with_context_value("guidelines", guidelines)
            .with_context_value("constraints", constraints)
            .with_context_value("additional", additional)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_customization() {
        let prompt = SystemPrompt::builder()
            .role("special role")
            .guidelines(["special guideline"])
            .constraints(vec!["special constraint".to_string()])
            .additional("some additional info")
            .build()
            .unwrap();

        let prompt: Prompt = prompt.into();
        let rendered = prompt.render().unwrap();
        assert!(rendered.contains("special role"), "error: {rendered}");
        assert!(rendered.contains("special guideline"), "error: {rendered}");
        assert!(rendered.contains("special constraint"), "error: {rendered}");
        assert!(
            rendered.contains("some additional info"),
            "error: {rendered}"
        );

        insta::assert_snapshot!(rendered);
    }

    #[tokio::test]
    async fn test_to_prompt() {
        let prompt = SystemPrompt::builder()
            .role("special role")
            .guidelines(["special guideline"])
            .constraints(vec!["special constraint".to_string()])
            .additional("some additional info")
            .build()
            .unwrap();

        let prompt: Prompt = prompt.to_prompt();
        let rendered = prompt.render().unwrap();
        assert!(rendered.contains("special role"), "error: {rendered}");
        assert!(rendered.contains("special guideline"), "error: {rendered}");
        assert!(rendered.contains("special constraint"), "error: {rendered}");
        assert!(
            rendered.contains("some additional info"),
            "error: {rendered}"
        );

        insta::assert_snapshot!(rendered);
    }

    #[tokio::test]
    async fn test_system_prompt_to_builder() {
        let sp = SystemPrompt {
            role: Some("Assistant".to_string()),
            guidelines: vec!["Be concise".to_string()],
            constraints: vec!["No personal opinions".to_string()],
            additional: None,
            template: "Hello, {{role}}!
Guidelines: {{guidelines}}, Constraints: {{constraints}}"
                .into(),
        };

        let builder = SystemPromptBuilder::from(sp.clone());

        assert_eq!(builder.role, Some(Some("Assistant".to_string())));
        assert_eq!(builder.guidelines, Some(vec!["Be concise".to_string()]));
        assert_eq!(
            builder.constraints,
            Some(vec!["No personal opinions".to_string()])
        );

        // For template, compare the rendered string
        assert_eq!(
            builder.template.as_ref().unwrap().render().unwrap(),
            sp.template.render().unwrap()
        );
    }

    #[test]
    fn test_with_added_guideline_and_constraint() {
        let mut sp = SystemPrompt::default();
        sp.with_added_guideline("Stay polite")
            .with_added_guideline("Use Markdown")
            .with_added_constraint("No personal info")
            .with_added_constraint("Short responses");

        assert_eq!(sp.guidelines, vec!["Stay polite", "Use Markdown"]);
        assert_eq!(sp.constraints, vec!["No personal info", "Short responses"]);
    }

    #[test]
    fn test_with_guidelines_and_constraints_overwrites() {
        let mut sp = SystemPrompt::default();
        sp.with_guidelines(["A", "B", "C"])
            .with_constraints(vec!["X", "Y"]);

        assert_eq!(sp.guidelines, vec!["A", "B", "C"]);
        assert_eq!(sp.constraints, vec!["X", "Y"]);

        // Overwrite with different contents
        sp.with_guidelines(vec!["Z"]);
        sp.with_constraints(["P", "Q"]);
        assert_eq!(sp.guidelines, vec!["Z"]);
        assert_eq!(sp.constraints, vec!["P", "Q"]);
    }

    #[test]
    fn test_with_role_and_additional_and_template() {
        let mut sp = SystemPrompt::default();
        sp.with_role("explainer")
            .with_additional("AGENTS.md here")
            .with_template("Template: {{role}}");

        assert_eq!(sp.role.as_deref(), Some("explainer"));
        assert_eq!(sp.additional.as_deref(), Some("AGENTS.md here"));
        assert_eq!(sp.template.render().unwrap(), "Template: {{role}}");
    }
}

================================================
FILE: swiftide-agents/src/system_prompt_template.md
================================================
{% if role -%}
# Your role

{{role}}

{% endif -%}
# Guidelines you need to follow
{# Guidelines provide soft rules and best practices to
complete a task well -#} - Try to understand how to complete the task well before completing it. {% for item in guidelines -%} - {{item}} {% endfor %} # Constraints that must be adhered to {# Constraints are hard limitations that an agent must follow -#} - Think step by step - Think before you act; respond with your thoughts before calling tools - Do not make up any assumptions, use tools to get the information you need - Use the provided tools to interact with the system and accomplish the task - If you are stuck, or otherwise cannot complete the task, respond with your thoughts and call `stop`. - If the task is completed, or otherwise cannot continue, like requiring user feedback, call `stop`. {% for item in constraints -%} - {{item}} {% endfor %} # Response Format {# Instruct the agent to always respond with their thoughts (chain-of-thought) -#} - Always respond with your thoughts and reasoning for your actions in one or two sentences. Even when calling tools. - Once the goal is achieved, call the `stop` tool {{additional}} ================================================ FILE: swiftide-agents/src/tasks/closures.rs ================================================ use std::pin::Pin; use async_trait::async_trait; use super::{ errors::NodeError, node::{NodeArg, NodeId, TaskNode}, }; #[derive(Clone)] pub struct SyncFn where F: Fn(&I) -> Result + Send + Sync + Clone + 'static, { pub f: F, _phantom: std::marker::PhantomData<(I, O)>, } #[derive(Clone)] pub struct AsyncFn where F: for<'a> Fn(&'a I) -> Pin> + Send + 'a>> + Send + Sync + Clone + 'static, { pub f: F, _phantom: std::marker::PhantomData<(I, O)>, } impl SyncFn where F: Fn(&I) -> Result + Send + Sync + Clone + 'static, I: NodeArg + Clone, O: NodeArg + Clone, { pub fn new(f: F) -> Self { SyncFn { f, _phantom: std::marker::PhantomData, } } } impl AsyncFn where F: for<'a> Fn(&'a I) -> Pin> + Send + 'a>> + Send + Sync + Clone + 'static, I: NodeArg + Clone, O: NodeArg + Clone, { pub fn new(f: F) -> Self { AsyncFn { 
f, _phantom: std::marker::PhantomData, } } } impl From for SyncFn where F: Fn(&()) -> Result<(), NodeError> + Send + Sync + Clone + 'static, { fn from(f: F) -> Self { SyncFn::new(f) } } impl From for AsyncFn where F: for<'a> Fn(&'a ()) -> Pin> + Send + 'a>> + Send + Sync + Clone + 'static, { fn from(f: F) -> Self { AsyncFn::new(f) } } #[async_trait] impl TaskNode for SyncFn where F: Fn(&I) -> Result + Clone + Send + Sync + 'static, I: NodeArg + Clone, O: NodeArg + Clone, { type Input = I; type Output = O; type Error = NodeError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { (self.f)(input) } } #[async_trait] impl TaskNode for AsyncFn where F: for<'a> Fn(&'a I) -> Pin> + Send + 'a>> + Clone + Send + Sync + 'static, I: NodeArg + Clone, O: NodeArg + Clone, { type Input = I; type Output = O; type Error = NodeError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { (self.f)(input).await } } ================================================ FILE: swiftide-agents/src/tasks/errors.rs ================================================ use std::{any::Any, sync::Arc}; use super::transition::TransitionPayload; #[derive(thiserror::Error, Debug)] pub enum TaskError { #[error(transparent)] NodeError(#[from] NodeError), #[error("MissingTransition: {0}")] MissingTransition(String), #[error("MissingNode: {0}")] MissingNode(String), #[error("Task failed with wrong output")] TypeError(String), #[error("MissingInput: {0}")] MissingInput(String), #[error("MissingOutput: {0}")] MissingOutput(String), #[error("Task is missing steps")] NoSteps, } impl TaskError { pub fn missing_transition(node_id: usize) -> Self { TaskError::MissingTransition(format!("Node {node_id} is missing a transition")) } pub fn missing_node(node_id: usize) -> Self { TaskError::MissingNode(format!("Node {node_id} is missing")) } pub fn missing_input(node_id: usize) -> Self { TaskError::MissingInput(format!("Node {node_id} 
is missing input")) } pub fn missing_output(node_id: usize) -> Self { TaskError::MissingOutput(format!("Node {node_id} is missing output")) } pub fn type_error(output: &T) -> Self { let message = format!( "Expected output of type {}, but got {:?}", std::any::type_name::(), output.type_id() ); TaskError::TypeError(message) } } #[derive(Debug, thiserror::Error)] pub struct NodeError { pub node_error: Box, pub transition_payload: Option>, pub node_id: usize, } impl std::fmt::Display for NodeError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "Node error in node {}: {:?}", self.node_id, self.node_error ) } } impl NodeError { pub fn new( node_error: impl Into>, node_id: usize, transition_payload: Option, ) -> Self { Self { node_error: node_error.into(), transition_payload: transition_payload.map(Arc::new), node_id, } } } ================================================ FILE: swiftide-agents/src/tasks/impls.rs ================================================ use std::sync::Arc; use async_trait::async_trait; use swiftide_core::{ ChatCompletion, Command, CommandError, CommandOutput, SimplePrompt, ToolExecutor, chat_completion::{ChatCompletionRequest, ChatCompletionResponse, errors::LanguageModelError}, prompt::Prompt, }; use tokio::sync::Mutex; use crate::{Agent, errors::AgentError}; use super::node::{NodeArg, NodeId, TaskNode}; /// An example of wrapping an Agent as a `TaskNode` /// /// For more control you can always roll your own #[derive(Clone, Debug)] pub struct TaskAgent(Arc>); impl From for TaskAgent { fn from(agent: Agent) -> Self { TaskAgent(Arc::new(Mutex::new(agent))) } } /// A 'default' implementation for an agent where there is no output #[async_trait] impl TaskNode for TaskAgent { type Input = Prompt; type Output = (); type Error = AgentError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { self.0.lock().await.query(input.clone()).await } } #[async_trait] impl TaskNode for 
Box { type Input = Prompt; type Output = String; type Error = LanguageModelError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { // TODO: Prompt should be borrowed self.prompt(input.clone()).await } } #[async_trait] impl TaskNode for Arc { type Input = Prompt; type Output = String; type Error = LanguageModelError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { // TODO: Prompt should be borrowed self.prompt(input.clone()).await } } #[async_trait] impl TaskNode for Box { type Input = ChatCompletionRequest<'static>; type Output = ChatCompletionResponse; type Error = LanguageModelError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { self.complete(input).await } } #[async_trait] impl TaskNode for Arc { type Input = ChatCompletionRequest<'static>; type Output = ChatCompletionResponse; type Error = LanguageModelError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { self.complete(input).await } } #[async_trait] impl TaskNode for Box { type Input = Command; type Output = CommandOutput; type Error = CommandError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { self.exec_cmd(input).await } } #[async_trait] impl TaskNode for Arc { type Input = Command; type Output = CommandOutput; type Error = CommandError; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { self.exec_cmd(input).await } } // Note: This only works for function pointers, not closures. 
// Note: This impl covers plain function pointers; closures are handled by the
// `SyncFn`/`AsyncFn` wrappers in `closures.rs`.
#[async_trait]
impl<I, O, E> TaskNode for fn(&I) -> Result<O, E>
where
    I: NodeArg + Clone,
    O: NodeArg + Clone,
    E: std::error::Error + Send + Sync + 'static,
{
    type Input = I;
    type Output = O;
    type Error = E;

    /// Evaluates by invoking the function pointer directly; the node id is unused.
    async fn evaluate(
        &self,
        _node_id: &NodeId<dyn TaskNode<Input = I, Output = O, Error = E>>,
        input: &Self::Input,
    ) -> Result<Self::Output, Self::Error> {
        (self)(input)
    }
}

================================================
FILE: swiftide-agents/src/tasks/mod.rs
================================================
pub mod closures;
pub mod errors;
pub mod impls;
pub mod node;
pub mod task;
pub mod transition;

================================================
FILE: swiftide-agents/src/tasks/node.rs
================================================
use std::any::Any;

use async_trait::async_trait;
use dyn_clone::DynClone;

use super::{
    errors::NodeError,
    transition::{MarkedTransitionPayload, TransitionPayload},
};

/// Marker trait for any value that can flow between task nodes.
///
/// Blanket-implemented for every `Send + Sync + DynClone + 'static` type, so
/// user types rarely need to implement it by hand.
pub trait NodeArg: Send + Sync + DynClone + 'static {}
impl<T: Send + Sync + DynClone + 'static> NodeArg for T {}

/// A node that does nothing and produces `()`; the task graph uses it as the
/// terminal `done` sink (node id 0).
#[derive(Debug, Clone)]
pub struct NoopNode<Context> {
    // NOTE(review): the second tuple element's type parameter was lost in text
    // extraction; `Box<Context>` is the most plausible reading -- confirm upstream.
    _marker: std::marker::PhantomData<(Context, Box<Context>)>,
}

impl<Context> Default for NoopNode<Context>
where
    Context: NodeArg,
{
    fn default() -> Self {
        NoopNode {
            _marker: std::marker::PhantomData,
        }
    }
}

#[async_trait]
impl<Context: NodeArg> TaskNode for NoopNode<Context> {
    type Output = ();
    type Input = Context;
    type Error = NodeError;

    /// Always succeeds and discards its input.
    async fn evaluate(
        &self,
        _node_id: &DynNodeId<Self>,
        _context: &Context,
    ) -> Result<Self::Output, Self::Error> {
        Ok(())
    }
}

/// A single evaluatable step in a task graph.
///
/// An implementor receives its own type-erased [`NodeId`] plus a borrowed
/// input and produces either a typed output or a typed error.
#[async_trait]
pub trait TaskNode: Send + Sync + DynClone + Any {
    type Input: NodeArg;
    type Output: NodeArg;
    type Error: std::error::Error + Send + Sync + 'static;

    async fn evaluate(
        &self,
        node_id: &DynNodeId<Self>,
        input: &Self::Input,
    ) -> Result<Self::Output, Self::Error>;
}

/// Shorthand for the type-erased `NodeId` matching a concrete node type `T`.
pub type DynNodeId<T: TaskNode + ?Sized> = NodeId<
    dyn TaskNode<
        Input = <T as TaskNode>::Input,
        Output = <T as TaskNode>::Output,
        Error = <T as TaskNode>::Error,
    >,
>;

dyn_clone::clone_trait_object!(
    TaskNode<
        Input = dyn NodeArg,
        Output = dyn NodeArg,
        Error = dyn std::error::Error + Send + Sync,
    >
);

/// Boxed trait objects forward evaluation to the boxed node.
#[async_trait]
impl<Input: NodeArg, Output: NodeArg, Error> TaskNode
    for Box<dyn TaskNode<Input = Input, Output = Output, Error = Error>>
where
    Error: std::error::Error + Send + Sync + 'static,
{
    type Input = Input;
    type Output = Output;
    type Error = Error;

    async fn evaluate(
        &self,
        node_id: &NodeId<dyn TaskNode<Input = Input, Output = Output, Error = Error>>,
        input: &Self::Input,
    ) -> Result<Self::Output, Self::Error> {
        self.as_ref().evaluate(node_id, input).await
    }
}
dyn_clone::clone_trait_object!( TaskNode); #[derive(PartialEq, Eq)] pub struct NodeId { pub id: usize, _marker: std::marker::PhantomData, } impl std::fmt::Debug for NodeId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let type_name = std::any::type_name::(); write!(f, "NodeId<{type_name}>({})", self.id) } } pub type AnyNodeId = usize; impl NodeId { pub fn id(&self) -> usize { self.id } /// Returns a closure that can be used as a transition function pub fn as_transition(&self) -> impl Fn(T::Input) -> MarkedTransitionPayload + 'static { let node_id = *self; Box::new(move |context| node_id.transitions_with(context)) } /// Returns a transition payload suitable for inside a task transition /// /// You can also get the closure version with `as_transition` pub fn transitions_with(&self, context: T::Input) -> MarkedTransitionPayload { MarkedTransitionPayload::new(TransitionPayload::next_node(self, context)) } } impl NodeId { pub fn new(id: usize, _node: &T) -> Self { NodeId { id, _marker: std::marker::PhantomData, } } /// Returns the internal id of the node without the type information. pub fn as_any(&self) -> AnyNodeId { self.id } pub fn as_dyn( self, ) -> NodeId> { NodeId { id: self.id, _marker: std::marker::PhantomData, } } } impl Clone for NodeId { fn clone(&self) -> Self { *self } } impl Copy for NodeId {} ================================================ FILE: swiftide-agents/src/tasks/task.rs ================================================ //! Tasks enable you to to define a graph of interacting nodes //! //! The nodes can be any type that implements the `TaskNode` trait, which defines how the node //! will be evaluated with its input and output. //! //! Most swiftide primitives implement `TaskNode`, and it's easy to implement your own. Since how //! agents interact is subject to taste, we recommend implementing your own. //! //! WARN: Here be dragons! This api is not stable yet. We are using it in production, and is //! 
subject to rapid change. However, do not hesitate to open an issue if you find anything. use std::{any::Any, pin::Pin, sync::Arc}; use tracing::Instrument as _; use crate::tasks::{errors::NodeError, transition::TransitionFn}; use super::{ errors::TaskError, node::{NodeArg, NodeId, NoopNode, TaskNode}, transition::{AnyNodeTransition, MarkedTransitionPayload, Transition, TransitionPayload}, }; #[derive(Debug)] pub struct Task { nodes: Vec>, current_node: usize, start_node: usize, current_context: Option>, _marker: std::marker::PhantomData<(Input, Output)>, } impl Clone for Task { fn clone(&self) -> Self { Self { nodes: self.nodes.clone(), current_node: 0, start_node: self.start_node, current_context: None, _marker: std::marker::PhantomData, } } } impl Default for Task { fn default() -> Self { Self::new() } } impl Task { pub fn new() -> Self { let noop = NoopNode::::default(); let node_id = NodeId::new(0, &noop).as_dyn(); let noop_executor = Box::new(Transition { node: Box::new(noop), node_id: Box::new(node_id), r#fn: Arc::new(|_output| { Box::pin(async { unreachable!("Done node should never be evaluated.") }) }), is_set: false, }); Self { nodes: vec![noop_executor], current_node: 0, start_node: 0, current_context: None, _marker: std::marker::PhantomData, } } /// Returns the current context as the input type, if it matches pub fn current_input(&self) -> Option<&Input> { let input = self.current_context.as_ref()?; input.downcast_ref::() } /// Returns the current context as the output type, if it matches pub fn current_output(&self) -> Option<&Output> { let input = self.current_context.as_ref()?; input.downcast_ref::() } /// Returns the `done` node for this task pub fn done(&self) -> NodeId> { NodeId::new(0, &NoopNode::default()) } /// Creates a transition to the done node pub fn transitions_to_done( &self, ) -> impl Fn(Output) -> MarkedTransitionPayload> + Send + Sync + 'static { let done = self.done(); move |context| done.transitions_with(context) } /// Defines the 
start node of the task pub fn starts_with + Clone + 'static>( &mut self, node_id: NodeId, ) { self.current_node = node_id.id; self.start_node = node_id.id; } /// Validates that all nodes have transitions set /// /// # Errors /// /// Errors if a node is missing a transition pub fn validate_transitions(&self) -> Result<(), TaskError> { // TODO: Validate that the task can complete for node_executor in &self.nodes { // Skip the done node (index 0) if node_executor.node_id() == 0 { continue; } if !node_executor.transition_is_set() { return Err(TaskError::missing_transition(node_executor.node_id())); } } Ok(()) } /// Runs the task with the given input /// /// # Errors /// /// Errors if the task fails #[tracing::instrument(skip(self, input), name = "task.run", err)] pub async fn run(&mut self, input: impl Into) -> Result, TaskError> { self.validate_transitions()?; self.current_context = Some(Arc::new(input.into()) as Arc); self.start_task().await } /// Resets the task to the start node /// /// WARN: This **will** lead to a type mismatch if the previous context is not the same as the /// input of the start node pub fn reset(&mut self) { self.current_node = self.start_node; } /// Resumes the task from the current node /// /// # Errors /// /// Errors if the task fails #[tracing::instrument(skip(self), name = "task.resume", err)] pub async fn resume(&mut self) -> Result, TaskError> { self.start_task().await } async fn start_task(&mut self) -> Result, TaskError> { self.validate_transitions()?; let mut span = tracing::info_span!("task.step", node = self.current_node); loop { if self.current_node == 0 { break; } let node_transition = self .nodes .get(self.current_node) .ok_or_else(|| TaskError::missing_node(self.current_node))?; let input = self .current_context .clone() .ok_or_else(|| TaskError::missing_input(self.current_node))?; tracing::debug!("Running node {}", self.current_node); let span_id = span.id().clone(); let transition_payload = node_transition 
.evaluate_next(input) .instrument(span.or_current()) .await?; match transition_payload { TransitionPayload::Pause => { tracing::info!("Task paused at node {}", self.current_node); return Ok(None); } TransitionPayload::NextNode(transition_payload) => { self.current_node = transition_payload.node_id; self.current_context = Some(transition_payload.context); } TransitionPayload::Error(error) => { return Err(TaskError::NodeError(NodeError::new( error, self.current_node, None, ))); } } if self.current_node == 0 { tracing::debug!("Task completed at node {}", self.current_node); break; } span = tracing::info_span!("task.step", node = self.current_node).or_current(); span.follows_from(span_id); } let output = self .current_context .clone() .ok_or_else(|| TaskError::missing_output(self.current_node))?; let output = output .downcast::() .map_err(|e| TaskError::type_error(&e))? .as_ref() .clone(); Ok(Some(output)) } /// Gets the current node of the task pub fn current_node(&self) -> Option<&T> { self.node_at_index(self.current_node) } /// Gets the node at the given `NodeId` pub fn node_at(&self, node_id: NodeId) -> Option<&T> { self.node_at_index(node_id.id) } /// Gets the node at the given index pub fn node_at_index(&self, index: usize) -> Option<&T> { let transition = self.transition_at_index::(index)?; let node = &*transition.node; (node as &dyn Any).downcast_ref::() } /// Gets the current transition of the task #[allow(dead_code)] fn current_transition( &self, ) -> Option<&Transition> { self.transition_at_index::(self.current_node) } /// Gets the transition at the given `NodeId` fn transition_at_index( &self, index: usize, ) -> Option<&Transition> { tracing::debug!("Getting transition at index {}", index); let transition = self.nodes.get(index)?; dbg!(&transition); (&**transition as &dyn Any).downcast_ref::>() } /// Registers a new node in the task pub fn register_node(&mut self, node: T) -> NodeId where T: TaskNode + 'static + Clone, ::Input: Clone, ::Output: Clone, { let 
id = self.nodes.len(); let node_id = NodeId::new(id, &node); let node_executor = Box::new(Transition:: { node_id: Box::new(node_id.as_dyn()), node: Box::new(node), r#fn: Arc::new(move |_output| unreachable!("No transition for node {}.", node_id.id)), is_set: false, }); // Debug the type name tracing::debug!(node_id = ?node_id, type_name = std::any::type_name_of_val(&node_executor), "Registering node"); self.nodes.push(node_executor); node_id } /// Registers a transition from one node to another /// /// Note that there are various helpers and conversions for the `MarkedTransitionPayload` /// /// # Errors /// /// Errors if the node does not exist pub fn register_transition<'a, From, To, F>( &mut self, from: NodeId, transition: F, ) -> Result<(), TaskError> where From: TaskNode + 'static + ?Sized, To: TaskNode + 'a + ?Sized, F: Fn(To::Input) -> MarkedTransitionPayload + Send + Sync + 'static, { let node_executor = self .nodes .get_mut(from.id) .ok_or_else(|| TaskError::missing_node(from.id))?; let any_executor: &mut dyn Any = node_executor.as_mut(); let Some(exec) = any_executor.downcast_mut::>() else { let expected = std::any::type_name::>(); let actual = std::any::type_name_of_val(node_executor); unreachable!( "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}", from.id ); }; let transition = Arc::new(transition); let wrapped: Arc> = Arc::new(move |output: From::Output| { let transition = transition.clone(); Box::pin(async move { let output = transition(output); output.into_inner() }) }); exec.r#fn = wrapped; exec.is_set = true; // set function as before Ok(()) } /// Registers a transition from one node to another asynchronously /// /// Note that there are various helpers and conversions for the `MarkedTransitionPayload` /// /// # Errors /// /// Errors if the node does not exist /// /// NOTE: `AsyncFn` traits' returned future are not 'Send' and the inner type is unstable. 
/// When they are, we can update Fn to `AsyncFn` pub fn register_transition_async<'a, From, To, F>( &mut self, from: NodeId, transition: F, ) -> Result<(), TaskError> where From: TaskNode + 'static + ?Sized, To: TaskNode + 'a + ?Sized, F: Fn(To::Input) -> Pin> + Send>> + Send + Sync + 'static, { let node_executor = self .nodes .get_mut(from.id) .ok_or_else(|| TaskError::missing_node(from.id))?; let any_executor: &mut dyn Any = node_executor.as_mut(); let Some(exec) = any_executor.downcast_mut::>() else { let expected = std::any::type_name::>(); let actual = std::any::type_name_of_val(node_executor); unreachable!( "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}", from.id ); }; let transition = Arc::new(transition); let wrapped: Arc> = Arc::new(move |output: From::Output| { let transition = transition.clone(); Box::pin(async move { let output = transition(output).await; output.into_inner() }) }); exec.r#fn = wrapped; exec.is_set = true; // set function as before Ok(()) } } #[cfg(test)] mod tests { use async_trait::async_trait; use super::*; #[derive(thiserror::Error, Debug)] struct Error(String); impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } #[derive(Clone, Default, Debug)] struct IntNode; #[async_trait] impl TaskNode for IntNode { type Input = i32; type Output = i32; type Error = Error; async fn evaluate( &self, _node_id: &NodeId< dyn TaskNode, >, input: &Self::Input, ) -> Result { Ok(input + 1) } } // Implement other required traits if necessary... 
#[test_log::test(tokio::test)] async fn sequential_3_node_task_reset_works() { let mut task: Task = Task::new(); // Register three nodes let node1 = task.register_node(IntNode); let node2 = task.register_node(IntNode); let node3 = task.register_node(IntNode); // Set start node task.starts_with(node1); // Register transitions (node1 → node2 → node3 → done) task.register_transition::<_, _, _>(node1, move |input| node2.transitions_with(input)) .unwrap(); task.register_transition::<_, _, _>(node2, move |input| node3.transitions_with(input)) .unwrap(); task.register_transition::<_, _, _>(node3, task.transitions_to_done()) .unwrap(); // Run the task to completion let res = task.run(1).await.unwrap(); assert_eq!(res, Some(4)); // 1 + 1 + 1 + 1 // Reset the task task.reset(); // Assert current_node returns the correct node (node1) dbg!(&task); let n1_transition = task.transition_at_index::(1); assert!(n1_transition.is_some()); let n1_transition = task.current_transition::(); assert!(n1_transition.is_some()); let n1_ref = task.current_node::(); assert!(n1_ref.is_some()); } } ================================================ FILE: swiftide-agents/src/tasks/transition.rs ================================================ use std::{any::Any, pin::Pin, sync::Arc}; use async_trait::async_trait; use dyn_clone::DynClone; use super::{ errors::NodeError, node::{NodeArg, NodeId, TaskNode}, }; pub trait TransitionFn: for<'a> Fn(Input) -> Pin + Send>> + Send + Sync { } // dyn_clone::clone_trait_object!( TransitionFn); impl TransitionFn for F where F: for<'a> Fn(Input) -> Pin + Send>> + Send + Sync { } pub(crate) struct Transition< Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + Sync + 'static, > { pub(crate) node: Box + Send + Sync>, pub(crate) node_id: Box>>, // pub(crate) r#fn: Arc TransitionPayload + Send + Sync>, pub(crate) r#fn: Arc + Send>, pub(crate) is_set: bool, } impl Clone for Transition where Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + 
Sync + 'static, { fn clone(&self) -> Self { Transition { node: self.node.clone(), node_id: self.node_id.clone(), r#fn: self.r#fn.clone(), is_set: self.is_set, } } } impl std::fmt::Debug for Transition { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Transition") .field("node_id", &self.node_id.id) .field("is_set", &self.is_set) .finish() } } #[derive(Debug, Clone)] pub struct NextNode { // If we make this an enum instead, we can support spawning many nodes as well pub(crate) node_id: usize, pub(crate) context: Arc, } impl NextNode { pub fn new(node_id: NodeId, context: T::Input) -> Self where ::Input: 'static, { let context = Arc::new(context) as Arc; NextNode { node_id: node_id.id, context, } } } impl From for TransitionPayload { fn from(next_node: NextNode) -> Self { TransitionPayload::NextNode(next_node) } } #[derive(Debug)] pub enum TransitionPayload { NextNode(NextNode), Pause, Error(Box), } impl TransitionPayload { pub fn next_node(node_id: &NodeId, context: T::Input) -> Self { NextNode::new(*node_id, context).into() } pub fn pause() -> Self { TransitionPayload::Pause } pub fn error(error: impl Into>) -> Self { TransitionPayload::Error(error.into()) } } pub struct MarkedTransitionPayload( TransitionPayload, std::marker::PhantomData, ); impl MarkedTransitionPayload { pub fn new(payload: TransitionPayload) -> Self { MarkedTransitionPayload(payload, std::marker::PhantomData) } pub fn into_inner(self) -> TransitionPayload { self.0 } } impl std::ops::Deref for MarkedTransitionPayload { type Target = TransitionPayload; fn deref(&self) -> &Self::Target { &self.0 } } #[async_trait] pub(crate) trait AnyNodeTransition: Any + Send + Sync + std::fmt::Debug + DynClone { fn transition_is_set(&self) -> bool; async fn evaluate_next( &self, context: Arc, ) -> Result; fn node_id(&self) -> usize; } dyn_clone::clone_trait_object!(AnyNodeTransition); #[async_trait] impl AnyNodeTransition for Transition { async fn evaluate_next( &self, 
context: Arc, ) -> Result { let context = context.downcast::().unwrap(); match self.node.evaluate(&self.node_id.as_dyn(), &context).await { Ok(output) => Ok((self.r#fn)(output).await), Err(error) => Err(NodeError::new(error, self.node_id.id, None)), /* node_id will be * set by caller */ } } fn transition_is_set(&self) -> bool { self.is_set } fn node_id(&self) -> usize { self.node_id.id } } ================================================ FILE: swiftide-agents/src/test_utils.rs ================================================ use std::borrow::Cow; use std::sync::{Arc, Mutex}; use async_trait::async_trait; use swiftide_core::chat_completion::ToolCall; use swiftide_core::chat_completion::{Tool, ToolOutput, ToolSpec, errors::ToolError}; use swiftide_core::AgentContext; use crate::Agent; use crate::hooks::{ AfterCompletionFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn, }; #[macro_export] macro_rules! chat_request { ($($message:expr),+; tools = [$($tool:expr),*]) => {{ let mut builder = swiftide_core::chat_completion::ChatCompletionRequest::builder(); builder.messages(vec![$($message),*]); let mut tool_specs = Vec::new(); $(tool_specs.push({ let tool = $tool; tool.tool_spec() });)* tool_specs.extend(Agent::default_tools().into_iter().map(|tool| tool.tool_spec())); builder.tool_specs(tool_specs); builder.build().unwrap() }}; ($($message:expr),+; tool_specs = [$($tool:expr),*]) => {{ let mut builder = swiftide_core::chat_completion::ChatCompletionRequest::builder(); builder.messages(vec![$($message),*]); let mut tool_specs = Vec::new(); $(tool_specs.push($tool);)* tool_specs.extend(Agent::default_tools().into_iter().map(|tool| tool.tool_spec())); builder.tool_specs(tool_specs); builder.build().unwrap() }} } #[macro_export] macro_rules! user { ($message:expr) => { swiftide_core::chat_completion::ChatMessage::new_user($message) }; } #[macro_export] macro_rules! 
system { ($message:expr) => { swiftide_core::chat_completion::ChatMessage::new_system($message) }; } #[macro_export] macro_rules! summary { ($message:expr) => { swiftide_core::chat_completion::ChatMessage::new_summary($message) }; } #[macro_export] macro_rules! assistant { ($message:expr) => { swiftide_core::chat_completion::ChatMessage::new_assistant( Some($message.to_string()), None, ) }; ($message:expr, [$($tool_call_name:expr),*]) => {{ let tool_calls = vec![ $( ToolCall::builder() .name($tool_call_name) .id("1") .build() .unwrap() ),* ]; ChatMessage::new_assistant(Some($message.to_string()), Some(tool_calls)) }}; } #[macro_export] macro_rules! tool_output { ($tool_name:expr, $message:expr) => {{ ChatMessage::ToolOutput( ToolCall::builder() .name($tool_name) .id("1") .build() .unwrap(), $message.into(), ) }}; } #[macro_export] macro_rules! tool_failed { ($tool_name:expr, $message:expr) => {{ ChatMessage::ToolOutput( ToolCall::builder() .name($tool_name) .id("1") .build() .unwrap(), ToolOutput::fail($message), ) }}; } #[macro_export] macro_rules! 
chat_response { ($message:expr; tool_calls = [$($tool_name:expr),*]) => {{ let tool_calls = vec![ $(ToolCall::builder().name($tool_name).id("1").build().unwrap()),* ]; ChatCompletionResponse::builder() .message($message) .tool_calls(tool_calls) .build() .unwrap() }}; (tool_calls = [$($tool_name:expr),*]) => {{ let tool_calls = vec![ $(ToolCall::builder().name($tool_name).id("1").build().unwrap()),* ]; ChatCompletionResponse::builder() .tool_calls(tool_calls) .build() .unwrap() }}; } type Expectations = Arc, Option<&'static str>)>>>; #[derive(Debug, Clone)] pub struct MockTool { expectations: Expectations, name: &'static str, } impl MockTool { #[allow(clippy::should_implement_trait)] pub fn default() -> Self { Self::new("mock_tool") } pub fn new(name: &'static str) -> Self { Self { expectations: Arc::new(Mutex::new(Vec::new())), name, } } pub fn expect_invoke_ok( &self, expected_result: ToolOutput, expected_args: Option<&'static str>, ) { self.expect_invoke(Ok(expected_result), expected_args); } #[allow(clippy::missing_panics_doc)] pub fn expect_invoke( &self, expected_result: Result, expected_args: Option<&'static str>, ) { self.expectations .lock() .unwrap() .push((expected_result, expected_args)); } } #[async_trait] impl Tool for MockTool { async fn invoke( &self, _agent_context: &dyn AgentContext, tool_call: &ToolCall, ) -> std::result::Result { tracing::debug!( "[MockTool] Invoked `{}` with args: {:?}", self.name, tool_call ); let expectation = self .expectations .lock() .unwrap() .pop() .unwrap_or_else(|| panic!("[MockTool] No expectations left for `{}`", self.name)); assert_eq!(expectation.1, tool_call.args()); expectation.0 } fn name(&self) -> Cow<'_, str> { self.name.into() } fn tool_spec(&self) -> ToolSpec { ToolSpec::builder() .name(self.name().as_ref()) .description("A fake tool for testing purposes") .build() .unwrap() } } impl From for Box { fn from(val: MockTool) -> Self { Box::new(val) as Box } } impl Drop for MockTool { fn drop(&mut self) { // Mock 
still borrowed elsewhere and expectations still be invoked if Arc::strong_count(&self.expectations) > 1 { return; } if self.expectations.lock().is_err() { return; } let name = self.name; if self.expectations.lock().unwrap().is_empty() { tracing::debug!("[MockTool] All expectations were met for `{name}`"); } else { panic!( "[MockTool] Not all expectations were met for `{name}: {:?}", *self.expectations.lock().unwrap() ); } } } #[derive(Debug, Clone)] pub struct MockHook { name: &'static str, called: Arc>, expected_calls: usize, } impl MockHook { pub fn new(name: &'static str) -> Self { Self { name, called: Arc::new(Mutex::new(0)), expected_calls: 0, } } pub fn expect_calls(&mut self, expected_calls: usize) -> &mut Self { self.expected_calls = expected_calls; self } #[allow(clippy::missing_panics_doc)] pub fn hook_fn(&self) -> impl BeforeAllFn + use<> { let called = Arc::clone(&self.called); move |_: &Agent| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } #[allow(clippy::missing_panics_doc)] pub fn on_start_fn(&self) -> impl OnStartFn + use<> { let called = Arc::clone(&self.called); move |_: &Agent| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } #[allow(clippy::missing_panics_doc)] pub fn before_completion_fn(&self) -> impl BeforeCompletionFn + use<> { let called = Arc::clone(&self.called); move |_: &Agent, _| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } #[allow(clippy::missing_panics_doc)] pub fn after_completion_fn(&self) -> impl AfterCompletionFn + use<> { let called = Arc::clone(&self.called); move |_: &Agent, _| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } #[allow(clippy::missing_panics_doc)] pub fn after_tool_fn(&self) -> impl AfterToolFn + 
use<> { let called = Arc::clone(&self.called); move |_: &Agent, _, _| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } #[allow(clippy::missing_panics_doc)] pub fn before_tool_fn(&self) -> impl BeforeToolFn + use<> { let called = Arc::clone(&self.called); move |_: &Agent, _| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } #[allow(clippy::missing_panics_doc)] pub fn message_hook_fn(&self) -> impl MessageHookFn + use<> { let called = Arc::clone(&self.called); move |_: &Agent, _| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } #[allow(clippy::missing_panics_doc)] pub fn stop_hook_fn(&self) -> impl OnStopFn + use<> { let called = Arc::clone(&self.called); move |_: &Agent, _, _| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } #[allow(clippy::missing_panics_doc)] pub fn on_stream_fn(&self) -> impl OnStreamFn + use<> { let called = Arc::clone(&self.called); move |_: &Agent, _| { let called = Arc::clone(&called); Box::pin(async move { let mut called = called.lock().unwrap(); *called += 1; Ok(()) }) } } } impl Drop for MockHook { fn drop(&mut self) { if Arc::strong_count(&self.called) > 1 { return; } let Ok(called) = self.called.lock() else { return; }; if *called == self.expected_calls { tracing::debug!( "[MockHook] `{}` all expectations met; called {} times", self.name, *called ); } else { panic!( "[MockHook] `{}` was called {} times but expected {}", self.name, *called, self.expected_calls ) } } } ================================================ FILE: swiftide-agents/src/tools/arg_preprocessor.rs ================================================ use std::borrow::Cow; use serde_json::{Map, Value}; use swiftide_core::chat_completion::ToolCall; /// Preprocesses 
arguments for tool calls and tries to fix common errors /// This must be infallible and the result is always forwarded to the tool pub struct ArgPreprocessor; impl ArgPreprocessor { pub fn preprocess_tool_calls(tool_calls: &mut [ToolCall]) { for tool_call in tool_calls.iter_mut() { let args = Self::preprocess(tool_call.args()); if args.as_ref().is_some_and(|a| match a { Cow::Borrowed(_) => false, Cow::Owned(_) => true, }) { tool_call.with_args(args.map(|a| a.to_string())); } } } pub fn preprocess(value: Option<&str>) -> Option> { Some(take_first_occurrence_in_object(value?)) } } /// Strips duplicate keys from JSON objects fn take_first_occurrence_in_object(value: &str) -> Cow<'_, str> { let Ok(parsed) = &serde_json::from_str(value) else { return Cow::Borrowed(value); }; if let Value::Object(obj) = parsed { let mut new_map = Map::with_capacity(obj.len()); for (k, v) in obj { // Only insert if we haven't seen this key yet. new_map.entry(k).or_insert(v.clone()); } Cow::Owned(Value::Object(new_map).to_string()) } else { // If the top-level isn't even an object, just pass it as is, // or decide how you want to handle that situation. 
Cow::Borrowed(value) } } #[cfg(test)] mod tests { use super::*; use serde_json::json; #[test] fn test_preprocess_regular_json() { let input = json!({ "key1": "value1", "key2": "value2" }) .to_string(); let expected = json!({ "key1": "value1", "key2": "value2" }); let result = ArgPreprocessor::preprocess(Some(&input)); assert_eq!(result.as_deref(), Some(expected.to_string().as_str())); } #[test] fn test_preprocess_json_with_duplicate_keys() { let input = json!({ "key1": "value1", "key1": "value2" }) .to_string(); let expected = json!({ "key1": "value2" }); let result = ArgPreprocessor::preprocess(Some(&input)); assert_eq!(result.as_deref(), Some(expected.to_string().as_str())); } #[test] fn test_no_preprocess_invalid_json() { let input = "invalid json"; let result = ArgPreprocessor::preprocess(Some(input)); assert_eq!(result.as_deref(), Some(input)); } #[test] fn test_no_input() { let result = ArgPreprocessor::preprocess(None); assert_eq!(result, None); } } ================================================ FILE: swiftide-agents/src/tools/control.rs ================================================ //! Control tools manage control flow during agent's lifecycle. 
use anyhow::Result; use async_trait::async_trait; use schemars::{Schema, schema_for}; use std::borrow::Cow; use swiftide_core::{ AgentContext, ToolFeedback, chat_completion::{Tool, ToolCall, ToolOutput, ToolSpec, errors::ToolError}, }; /// `Stop` tool is a default tool used by agents to stop #[derive(Clone, Debug, Default)] pub struct Stop {} #[async_trait] impl Tool for Stop { async fn invoke( &self, _agent_context: &dyn AgentContext, _tool_call: &ToolCall, ) -> Result { Ok(ToolOutput::stop()) } fn name(&self) -> Cow<'_, str> { "stop".into() } fn tool_spec(&self) -> ToolSpec { ToolSpec::builder() .name("stop") .description("When you have completed, or cannot complete, your task, call this") .build() .unwrap() } } impl From for Box { fn from(val: Stop) -> Self { Box::new(val) } } /// `StopWithArgs` is an alternative stop tool that takes arguments #[derive(Clone, Debug)] pub struct StopWithArgs { parameters_schema: Option, expects_output_field: bool, } impl Default for StopWithArgs { fn default() -> Self { Self { parameters_schema: Some(schema_for!(DefaultStopWithArgsSpec)), expects_output_field: true, } } } impl StopWithArgs { /// Create a new `StopWithArgs` tool with a custom parameters schema. /// /// When providing a custom schema the full argument payload will be forwarded to the /// stop output without requiring an `output` field wrapper. 
pub fn with_parameters_schema(schema: Schema) -> Self { Self { parameters_schema: Some(schema), expects_output_field: false, } } fn parameters_schema(&self) -> Schema { self.parameters_schema .clone() .unwrap_or_else(|| schema_for!(DefaultStopWithArgsSpec)) } } #[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)] struct DefaultStopWithArgsSpec { pub output: String, } #[async_trait] impl Tool for StopWithArgs { async fn invoke( &self, _agent_context: &dyn AgentContext, tool_call: &ToolCall, ) -> Result { let raw_args = tool_call .args() .ok_or_else(|| ToolError::missing_arguments("arguments"))?; let json: serde_json::Value = serde_json::from_str(raw_args)?; let output = if self.expects_output_field { json.get("output") .cloned() .ok_or_else(|| ToolError::missing_arguments("output"))? } else { json }; Ok(ToolOutput::stop_with_args(output)) } fn name(&self) -> Cow<'_, str> { "stop".into() } fn tool_spec(&self) -> ToolSpec { let schema = self.parameters_schema(); ToolSpec::builder() .name("stop") .description("When you have completed, your task, call this with your expected output") .parameters_schema(schema) .build() .unwrap() } } impl From for Box { fn from(val: StopWithArgs) -> Self { Box::new(val) } } #[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)] struct AgentFailedArgsSpec { pub reason: String, } /// A utility tool that can be used to let an agent decide it failed /// /// This will _NOT_ have the agent return an error, instead, look at the stop reason of the agent. #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] pub struct AgentCanFail { parameters_schema: Option, expects_reason_field: bool, } impl Default for AgentCanFail { fn default() -> Self { Self { parameters_schema: Some(schema_for!(AgentFailedArgsSpec)), expects_reason_field: true, } } } impl AgentCanFail { /// Create a new `AgentCanFail` tool with a custom parameters schema. 
/// /// When providing a custom schema the full argument payload will be forwarded to the failure /// reason without requiring a `reason` field wrapper. pub fn with_parameters_schema(schema: Schema) -> Self { Self { parameters_schema: Some(schema), expects_reason_field: false, } } fn parameters_schema(&self) -> Schema { self.parameters_schema .clone() .unwrap_or_else(|| schema_for!(AgentFailedArgsSpec)) } } #[async_trait] impl Tool for AgentCanFail { async fn invoke( &self, _agent_context: &dyn AgentContext, tool_call: &ToolCall, ) -> Result { let raw_args = tool_call.args().ok_or_else(|| { if self.expects_reason_field { ToolError::missing_arguments("reason") } else { ToolError::missing_arguments("arguments") } })?; let reason = if self.expects_reason_field { let args: AgentFailedArgsSpec = serde_json::from_str(raw_args)?; args.reason } else { let json: serde_json::Value = serde_json::from_str(raw_args)?; json.to_string() }; Ok(ToolOutput::agent_failed(reason)) } fn name(&self) -> Cow<'_, str> { "task_failed".into() } fn tool_spec(&self) -> ToolSpec { let schema = self.parameters_schema(); ToolSpec::builder() .name("task_failed") .description("If you cannot complete your task, or have otherwise failed, call this with your reason for failure") .parameters_schema(schema) .build() .unwrap() } } impl From for Box { fn from(val: AgentCanFail) -> Self { Box::new(val) } } #[derive(Clone)] /// Wraps a tool and requires approval before it can be used pub struct ApprovalRequired(pub Box); impl ApprovalRequired { /// Creates a new `ApprovalRequired` tool pub fn new(tool: impl Tool + 'static) -> Self { Self(Box::new(tool)) } } #[async_trait] impl Tool for ApprovalRequired { async fn invoke( &self, context: &dyn AgentContext, tool_call: &ToolCall, ) -> Result { if let Some(feedback) = context.has_received_feedback(tool_call).await { match feedback { ToolFeedback::Approved { .. } => return self.0.invoke(context, tool_call).await, ToolFeedback::Refused { .. 
} => { return Ok(ToolOutput::text("This tool call was refused")); } } } Ok(ToolOutput::FeedbackRequired(None)) } fn name(&self) -> Cow<'_, str> { self.0.name() } fn tool_spec(&self) -> ToolSpec { self.0.tool_spec() } } impl From for Box { fn from(val: ApprovalRequired) -> Self { Box::new(val) } } #[cfg(test)] mod tests { use super::*; use schemars::schema_for; use serde_json::json; fn dummy_tool_call(name: &str, args: Option<&str>) -> ToolCall { let mut builder = ToolCall::builder().name(name).id("1").to_owned(); if let Some(args) = args { builder.args(args.to_string()); } builder.build().unwrap() } #[tokio::test] async fn test_stop_tool() { let stop = Stop::default(); let ctx = (); let tool_call = dummy_tool_call("stop", None); let out = stop.invoke(&ctx, &tool_call).await.unwrap(); assert_eq!(out, ToolOutput::stop()); } #[tokio::test] async fn test_stop_with_args_tool() { let tool = StopWithArgs::default(); let ctx = (); let args = r#"{"output":"expected result"}"#; let tool_call = dummy_tool_call("stop", Some(args)); let out = tool.invoke(&ctx, &tool_call).await.unwrap(); assert_eq!(out, ToolOutput::stop_with_args("expected result")); } #[tokio::test] async fn test_agent_can_fail_tool() { let tool = AgentCanFail::default(); let ctx = (); let args = r#"{"reason":"something went wrong"}"#; let tool_call = dummy_tool_call("task_failed", Some(args)); let out = tool.invoke(&ctx, &tool_call).await.unwrap(); assert_eq!(out, ToolOutput::agent_failed("something went wrong")); } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)] struct CustomFailArgs { code: i32, message: String, } #[test] fn test_agent_can_fail_custom_schema_in_spec() { let schema = schema_for!(CustomFailArgs); let tool = AgentCanFail::with_parameters_schema(schema.clone()); let spec = tool.tool_spec(); assert_eq!(spec.parameters_schema, Some(schema)); } #[tokio::test] async fn test_agent_can_fail_custom_schema_forwards_payload() { let schema = 
schema_for!(CustomFailArgs); let tool = AgentCanFail::with_parameters_schema(schema); let ctx = (); let args = r#"{"code":7,"message":"error"}"#; let tool_call = dummy_tool_call("task_failed", Some(args)); let out = tool.invoke(&ctx, &tool_call).await.unwrap(); assert_eq!( out, ToolOutput::agent_failed(json!({"code":7,"message":"error"}).to_string()) ); } #[test] fn test_agent_can_fail_default_schema_matches_previous() { let tool = AgentCanFail::default(); let spec = tool.tool_spec(); let expected = schema_for!(AgentFailedArgsSpec); assert_eq!(spec.parameters_schema, Some(expected)); } #[tokio::test] async fn test_approval_required_feedback_required() { let stop = Stop::default(); let tool = ApprovalRequired::new(stop); let ctx = (); let tool_call = dummy_tool_call("stop", None); let out = tool.invoke(&ctx, &tool_call).await.unwrap(); // On unit; existing feedback is always present assert_eq!(out, ToolOutput::Stop(None)); } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)] struct CustomStopArgs { value: i32, } #[test] fn test_stop_with_args_custom_schema_in_spec() { let schema = schema_for!(CustomStopArgs); let tool = StopWithArgs::with_parameters_schema(schema.clone()); let spec = tool.tool_spec(); assert_eq!(spec.parameters_schema, Some(schema)); } #[tokio::test] async fn test_stop_with_args_custom_schema_forwards_payload() { let schema = schema_for!(CustomStopArgs); let tool = StopWithArgs::with_parameters_schema(schema); let ctx = (); let args = r#"{"value":42}"#; let tool_call = dummy_tool_call("stop", Some(args)); let out = tool.invoke(&ctx, &tool_call).await.unwrap(); assert_eq!(out, ToolOutput::stop_with_args(json!({"value": 42}))); } #[test] fn test_stop_with_args_default_schema_matches_previous() { let tool = StopWithArgs::default(); let spec = tool.tool_spec(); let expected = schema_for!(DefaultStopWithArgsSpec); assert_eq!(spec.parameters_schema, Some(expected)); } } ================================================ 
FILE: swiftide-agents/src/tools/local_executor.rs ================================================ //! Local executor for running tools on the local machine. //! //! By default will use the current directory as the working directory. use std::{ collections::HashMap, path::{Path, PathBuf}, process::Stdio, time::Duration, }; use anyhow::{Context as _, Result}; use async_trait::async_trait; use derive_builder::Builder; use swiftide_core::{Command, CommandError, CommandOutput, Loader, ToolExecutor}; use swiftide_indexing::loaders::FileLoader; use tokio::{ io::{AsyncBufReadExt as _, AsyncWriteExt as _}, task::JoinHandle, time, }; #[derive(Debug, Clone, Builder)] pub struct LocalExecutor { #[builder(default = ".".into(), setter(into))] workdir: PathBuf, #[builder(default)] default_timeout: Option, /// Clears env variables before executing commands. #[builder(default)] pub(crate) env_clear: bool, /// Remove these environment variables before executing commands. #[builder(default, setter(into))] pub(crate) env_remove: Vec, /// Set these environment variables before executing commands. 
#[builder(default, setter(into))] pub(crate) envs: HashMap, } impl Default for LocalExecutor { fn default() -> Self { LocalExecutor { workdir: ".".into(), default_timeout: None, env_clear: false, env_remove: Vec::new(), envs: HashMap::new(), } } } impl LocalExecutor { pub fn new(workdir: impl Into) -> Self { LocalExecutor { workdir: workdir.into(), default_timeout: None, env_clear: false, env_remove: Vec::new(), envs: HashMap::new(), } } pub fn builder() -> LocalExecutorBuilder { LocalExecutorBuilder::default() } fn resolve_workdir(&self, cmd: &Command) -> PathBuf { match cmd.current_dir_path() { Some(path) if path.is_absolute() => path.to_path_buf(), Some(path) => self.workdir.join(path), None => self.workdir.clone(), } } fn resolve_timeout(&self, cmd: &Command) -> Option { cmd.timeout_duration().copied().or(self.default_timeout) } #[allow(clippy::too_many_lines)] async fn exec_shell( &self, cmd: &str, workdir: &Path, timeout: Option, ) -> Result { let lines: Vec<&str> = cmd.lines().collect(); let mut child = if let Some(first_line) = lines.first() && first_line.starts_with("#!") { let interpreter = first_line.trim_start_matches("#!/usr/bin/env ").trim(); tracing::info!(interpreter, "detected shebang; running as script"); let mut command = tokio::process::Command::new(interpreter); if self.env_clear { tracing::info!("clearing environment variables"); command.env_clear(); } for var in &self.env_remove { tracing::info!(var, "clearing environment variable"); command.env_remove(var); } for (key, value) in &self.envs { tracing::info!(key, "setting environment variable"); command.env(key, value); } let mut child = command .current_dir(workdir) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .spawn()?; if let Some(mut stdin) = child.stdin.take() { let body = lines[1..].join("\n"); stdin.write_all(body.as_bytes()).await?; } child } else { tracing::info!("no shebang detected; running as command"); let mut command = tokio::process::Command::new("sh"); 
// Treat as shell command command.arg("-c").arg(cmd).current_dir(workdir); if self.env_clear { tracing::info!("clearing environment variables"); command.env_clear(); } for var in &self.env_remove { tracing::info!(var, "clearing environment variable"); command.env_remove(var); } for (key, value) in &self.envs { tracing::info!(key, "setting environment variable"); command.env(key, value); } command .current_dir(workdir) .stdin(Stdio::null()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .spawn()? }; let stdout_task = if let Some(stdout) = child.stdout.take() { Some(tokio::spawn(async move { let mut lines = tokio::io::BufReader::new(stdout).lines(); let mut out = Vec::new(); while let Ok(Some(line)) = lines.next_line().await { out.push(line); } out })) } else { tracing::warn!("Command has no stdout"); None }; let stderr_task = if let Some(stderr) = child.stderr.take() { Some(tokio::spawn(async move { let mut lines = tokio::io::BufReader::new(stderr).lines(); let mut out = Vec::new(); while let Ok(Some(line)) = lines.next_line().await { out.push(line); } out })) } else { tracing::warn!("Command has no stderr"); None }; let status = match timeout { Some(limit) => { if let Ok(result) = time::timeout(limit, child.wait()).await { result.map_err(|err| CommandError::ExecutorError(err.into()))? 
} else { tracing::warn!(?limit, "command exceeded timeout; terminating"); if let Err(err) = child.start_kill() { tracing::warn!(?err, "failed to start kill on timed out command"); } if let Err(err) = child.wait().await { tracing::warn!(?err, "failed to reap command after timeout"); } let (stdout, stderr) = Self::collect_process_output(stdout_task, stderr_task).await; let cmd_output = Self::merge_output(&stdout, &stderr); return Err(CommandError::TimedOut { timeout: limit, output: cmd_output, }); } } None => child .wait() .await .map_err(|err| CommandError::ExecutorError(err.into()))?, }; let (stdout, stderr) = Self::collect_process_output(stdout_task, stderr_task).await; let cmd_output = Self::merge_output(&stdout, &stderr); if status.success() { Ok(cmd_output) } else { Err(CommandError::NonZeroExit(cmd_output)) } } async fn exec_read_file( &self, workdir: &Path, path: &Path, timeout: Option, ) -> Result { let path = if path.is_absolute() { path.to_path_buf() } else { workdir.join(path) }; let read_future = fs_err::tokio::read(&path); let output = match timeout { Some(limit) => match time::timeout(limit, read_future).await { Ok(result) => result?, Err(_) => { return Err(CommandError::TimedOut { timeout: limit, output: CommandOutput::empty(), }); } }, None => read_future.await?, }; Ok(String::from_utf8(output) .context("Failed to parse read file output")? 
.into()) } async fn exec_write_file( &self, workdir: &Path, path: &Path, content: &str, timeout: Option, ) -> Result { let path = if path.is_absolute() { path.to_path_buf() } else { workdir.join(path) }; if let Some(parent) = path.parent() { let _ = fs_err::tokio::create_dir_all(parent).await; } let write_future = fs_err::tokio::write(&path, content); match timeout { Some(limit) => match time::timeout(limit, write_future).await { Ok(result) => result?, Err(_) => { return Err(CommandError::TimedOut { timeout: limit, output: CommandOutput::empty(), }); } }, None => write_future.await?, } Ok(CommandOutput::empty()) } async fn collect_process_output( stdout_task: Option>>, stderr_task: Option>>, ) -> (Vec, Vec) { let stdout = match stdout_task { Some(task) => match task.await { Ok(lines) => lines, Err(err) => { tracing::warn!(?err, "failed to collect stdout from command"); Vec::new() } }, None => Vec::new(), }; let stderr = match stderr_task { Some(task) => match task.await { Ok(lines) => lines, Err(err) => { tracing::warn!(?err, "failed to collect stderr from command"); Vec::new() } }, None => Vec::new(), }; (stdout, stderr) } fn merge_output(stdout: &[String], stderr: &[String]) -> CommandOutput { stdout .iter() .chain(stderr.iter()) .cloned() .collect::>() .join("\n") .into() } } #[async_trait] impl ToolExecutor for LocalExecutor { /// Execute a `Command` on the local machine #[tracing::instrument(skip_self)] async fn exec_cmd(&self, cmd: &Command) -> Result { let workdir = __self.resolve_workdir(cmd); let timeout = __self.resolve_timeout(cmd); match cmd { Command::Shell { command, .. } => __self.exec_shell(command, &workdir, timeout).await, Command::ReadFile { path, .. } => __self.exec_read_file(&workdir, path, timeout).await, Command::WriteFile { path, content, .. 
} => { __self .exec_write_file(&workdir, path, content, timeout) .await } _ => unimplemented!("Unsupported command: {cmd:?}"), } } async fn stream_files( &self, path: &Path, extensions: Option>, ) -> Result> { let mut loader = FileLoader::new(path); if let Some(extensions) = extensions { loader = loader.with_extensions(&extensions); } Ok(loader.into_stream()) } } #[cfg(test)] mod tests { use super::*; use futures_util::StreamExt as _; use indoc::indoc; use std::{path::Path, sync::Arc, time::Duration}; use swiftide_core::{Command, ExecutorExt, ToolExecutor}; use temp_dir::TempDir; #[tokio::test] async fn test_local_executor_write_and_read_file() -> anyhow::Result<()> { // Create a temporary directory let temp_dir = TempDir::new()?; let temp_path = temp_dir.path(); // Instantiate LocalExecutor with the temporary directory as workdir let executor = LocalExecutor { workdir: temp_path.to_path_buf(), ..Default::default() }; // Define the file path and content let file_path = temp_path.join("test_file.txt"); let file_content = "Hello, world!"; // Write a shell command to create a file with the specified content let write_cmd = Command::shell(format!("echo '{}' > {}", file_content, file_path.display())); // Execute the write command executor.exec_cmd(&write_cmd).await?; // Verify that the file was created successfully assert!(file_path.exists()); // Write a shell command to read the file's content let read_cmd = Command::shell(format!("cat {}", file_path.display())); // Execute the read command let output = executor.exec_cmd(&read_cmd).await?; // Verify that the content read from the file matches the expected content assert_eq!(output.to_string(), format!("{file_content}")); let output = executor .exec_cmd(&Command::read_file(&file_path)) .await .unwrap(); assert_eq!(output.to_string(), format!("{file_content}\n")); Ok(()) } #[tokio::test] async fn test_local_executor_echo_hello_world() -> anyhow::Result<()> { // Create a temporary directory let temp_dir = TempDir::new()?; 
let temp_path = temp_dir.path(); // Instantiate LocalExecutor with the temporary directory as workdir let executor = LocalExecutor { workdir: temp_path.to_path_buf(), ..Default::default() }; // Define the echo command let echo_cmd = Command::shell("echo 'hello world'"); // Execute the echo command let output = executor.exec_cmd(&echo_cmd).await?; // Verify that the output matches the expected content assert_eq!(output.to_string().trim(), "hello world"); Ok(()) } #[tokio::test] async fn test_local_executor_shell_timeout() -> anyhow::Result<()> { let temp_dir = TempDir::new()?; let temp_path = temp_dir.path(); let executor = LocalExecutor { workdir: temp_path.to_path_buf(), ..Default::default() }; let mut cmd = Command::shell("echo ready && sleep 1 && echo done"); cmd.timeout(Duration::from_millis(100)); match executor.exec_cmd(&cmd).await { Err(CommandError::TimedOut { timeout, output }) => { assert_eq!(timeout, Duration::from_millis(100)); assert!(output.to_string().contains("ready")); } other => anyhow::bail!("expected timeout error, got {other:?}"), } Ok(()) } #[tokio::test] async fn test_local_executor_default_timeout_applies() -> anyhow::Result<()> { let temp_dir = TempDir::new()?; let temp_path = temp_dir.path(); let executor = LocalExecutorBuilder::default() .workdir(temp_path.to_path_buf()) .default_timeout(Some(Duration::from_millis(100))) .build()?; match executor.exec_cmd(&Command::shell("sleep 1")).await { Err(CommandError::TimedOut { timeout, output }) => { assert_eq!(timeout, Duration::from_millis(100)); assert!(output.to_string().is_empty()); } other => anyhow::bail!("expected default timeout, got {other:?}"), } Ok(()) } #[tokio::test] async fn test_local_executor_clear_env() -> anyhow::Result<()> { // Create a temporary directory let temp_dir = TempDir::new()?; let temp_path = temp_dir.path(); // Instantiate LocalExecutor with the temporary directory as workdir let executor = LocalExecutor { workdir: temp_path.to_path_buf(), env_clear: true, 
..Default::default() }; // Define the echo command let echo_cmd = Command::shell("printenv"); // Execute the echo command let output = executor.exec_cmd(&echo_cmd).await?.to_string(); // Verify that the output matches the expected content // assert_eq!(output.to_string().trim(), ""); assert!(!output.contains("CARGO_PKG_VERSION"), "{output}"); Ok(()) } #[tokio::test] async fn test_local_executor_add_env() -> anyhow::Result<()> { // Create a temporary directory let temp_dir = TempDir::new()?; let temp_path = temp_dir.path(); // Instantiate LocalExecutor with the temporary directory as workdir let executor = LocalExecutor { workdir: temp_path.to_path_buf(), envs: HashMap::from([("TEST_ENV".to_string(), "HELLO".to_string())]), ..Default::default() }; // Define the echo command let echo_cmd = Command::shell("printenv"); // Execute the echo command let output = executor.exec_cmd(&echo_cmd).await?.to_string(); // Verify that the output matches the expected content // assert_eq!(output.to_string().trim(), ""); assert!(output.contains("TEST_ENV=HELLO"), "{output}"); // Double tap its included by default assert!(output.contains("CARGO_PKG_VERSION"), "{output}"); Ok(()) } #[tokio::test] async fn test_local_executor_env_remove() -> anyhow::Result<()> { // Create a temporary directory let temp_dir = TempDir::new()?; let temp_path = temp_dir.path(); // Instantiate LocalExecutor with the temporary directory as workdir let executor = LocalExecutor { workdir: temp_path.to_path_buf(), env_remove: vec!["CARGO_PKG_VERSION".to_string()], ..Default::default() }; // Define the echo command let echo_cmd = Command::shell("printenv"); // Execute the echo command let output = executor.exec_cmd(&echo_cmd).await?.to_string(); // Verify that the output matches the expected content // assert_eq!(output.to_string().trim(), ""); assert!(!output.contains("CARGO_PKG_VERSION="), "{output}"); Ok(()) } #[tokio::test] async fn test_local_executor_run_shebang() -> anyhow::Result<()> { // Create a 
temporary directory let temp_dir = TempDir::new()?; let temp_path = temp_dir.path(); // Instantiate LocalExecutor with the temporary directory as workdir let executor = LocalExecutor { workdir: temp_path.to_path_buf(), ..Default::default() }; let script = r#"#!/usr/bin/env python3 print("hello from python") print(1 + 2)"#; // Execute the echo command let output = executor .exec_cmd(&Command::shell(script)) .await? .to_string(); // Verify that the output matches the expected content assert!(output.contains("hello from python")); assert!(output.contains('3')); Ok(()) } #[tokio::test] async fn test_local_executor_multiline_with_quotes() -> anyhow::Result<()> { // Create a temporary directory let temp_dir = TempDir::new()?; let temp_path = temp_dir.path(); // Instantiate LocalExecutor with the temporary directory as workdir let executor = LocalExecutor { workdir: temp_path.to_path_buf(), ..Default::default() }; // Define the file path and content let file_path = "test_file2.txt"; let file_content = indoc! 
{r#" fn main() { println!("Hello, world!"); } "#};

        // Write a shell command to create a file with the specified content
        let write_cmd = Command::shell(format!("echo '{file_content}' > {file_path}"));

        // Execute the write command
        executor.exec_cmd(&write_cmd).await?;

        // Write a shell command to read the file's content
        let read_cmd = Command::shell(format!("cat {file_path}"));

        // Execute the read command
        let output = executor.exec_cmd(&read_cmd).await?;

        // Verify that the content read from the file matches the expected content
        assert_eq!(output.to_string(), format!("{file_content}"));

        Ok(())
    }

    /// Round-trips a file through the dedicated `WriteFile`/`ReadFile` commands
    /// and checks that reading a missing file surfaces `CommandError::NonZeroExit`.
    #[tokio::test]
    async fn test_local_executor_write_and_read_file_commands() -> anyhow::Result<()> {
        // Create a temporary directory
        let temp_dir = TempDir::new()?;
        let temp_path = temp_dir.path();

        // Instantiate LocalExecutor with the temporary directory as workdir
        let executor = LocalExecutor {
            workdir: temp_path.to_path_buf(),
            ..Default::default()
        };

        // Define the file path and content
        let file_path = temp_path.join("test_file.txt");
        let file_content = "Hello, world!";

        // Assert that the file does not exist and it gives the correct error
        let cmd = Command::read_file(file_path.clone());
        let result = executor.exec_cmd(&cmd).await;
        if let Err(err) = result {
            assert!(matches!(err, CommandError::NonZeroExit(..)));
        } else {
            panic!("Expected error but got {result:?}");
        }

        // Create a write command
        let write_cmd = Command::write_file(file_path.clone(), file_content.to_string());

        // Execute the write command
        executor.exec_cmd(&write_cmd).await?;

        // Verify that the file was created successfully
        assert!(file_path.exists());

        // Create a read command
        let read_cmd = Command::read_file(file_path.clone());

        // Execute the read command
        let output = executor.exec_cmd(&read_cmd).await?.output;

        // Verify that the content read from the file matches the expected content
        assert_eq!(output, file_content);

        Ok(())
    }

    /// Checks that `stream_files` walks the workdir and honors the optional
    /// extension filter.
    #[tokio::test]
    async fn test_local_executor_stream_files() -> anyhow::Result<()> {
        // Create a temporary directory
        let temp_dir = TempDir::new()?;
        let temp_path = temp_dir.path();

        // Create some test files in the temporary directory
        fs_err::write(temp_path.join("file1.txt"), "Content of file 1")?;
        fs_err::write(temp_path.join("file2.txt"), "Content of file 2")?;
        fs_err::write(temp_path.join("file3.rs"), "Content of file 3")?;

        // Instantiate LocalExecutor with the temporary directory as workdir
        let executor = LocalExecutor {
            workdir: temp_path.to_path_buf(),
            ..Default::default()
        };

        // Stream files with no extensions filter
        let stream = executor.stream_files(temp_path, None).await?;
        let files: Vec<_> = stream.collect().await;
        assert_eq!(files.len(), 3);

        // Stream files with a specific extension filter
        let stream = executor
            .stream_files(temp_path, Some(vec!["txt".to_string()]))
            .await?;
        let txt_files: Vec<_> = stream.collect().await;
        assert_eq!(txt_files.len(), 2);

        Ok(())
    }

    /// Verifies every command type resolves relative paths against the
    /// executor's configured `workdir` rather than the process' current dir.
    #[tokio::test]
    async fn test_local_executor_honors_workdir() -> anyhow::Result<()> {
        use std::fs;
        use temp_dir::TempDir;

        // 1. Create a temp dir and instantiate executor
        let temp_dir = TempDir::new()?;
        let temp_path = temp_dir.path();
        let executor = LocalExecutor {
            workdir: temp_path.to_path_buf(),
            ..Default::default()
        };

        // 2. Run a shell command in workdir and check output is workdir
        let pwd_cmd = Command::shell("pwd");
        let pwd_output = executor.exec_cmd(&pwd_cmd).await?.to_string();
        // Canonicalize both sides: on macOS `/tmp` is a symlink to `/private/tmp`
        let pwd_path = std::fs::canonicalize(pwd_output.trim())?;
        let temp_path = std::fs::canonicalize(temp_path)?;
        assert_eq!(pwd_path, temp_path);

        // 3. Write a file using WriteFile (should land in workdir)
        let fname = "workdir_check.txt";
        let write_cmd = Command::write_file(fname, "test123");
        executor.exec_cmd(&write_cmd).await?;

        // 4. Assert file exists in workdir, not current dir
        let expected_path = temp_path.join(fname);
        assert!(expected_path.exists());
        assert!(!Path::new(fname).exists());

        // 5. Write/read using ReadFile
        let read_cmd = Command::read_file(fname);
        let read_output = executor.exec_cmd(&read_cmd).await?.to_string();
        assert_eq!(read_output.trim(), "test123");

        // 6. Clean up
        fs::remove_file(&expected_path)?;
        Ok(())
    }

    /// Checks that a per-command `current_dir` is resolved relative to the
    /// executor's workdir for shell, write and read commands alike.
    #[tokio::test]
    async fn test_local_executor_command_current_dir() -> anyhow::Result<()> {
        use std::fs;
        use temp_dir::TempDir;

        let temp_dir = TempDir::new()?;
        let base_path = temp_dir.path();
        let executor = LocalExecutor {
            workdir: base_path.to_path_buf(),
            ..Default::default()
        };

        let nested_dir = base_path.join("nested");
        fs::create_dir_all(&nested_dir)?;

        let mut pwd_cmd = Command::shell("pwd");
        pwd_cmd.current_dir(Path::new("nested"));
        let pwd_output = executor.exec_cmd(&pwd_cmd).await?.to_string();
        let pwd_path = std::fs::canonicalize(pwd_output.trim())?;
        assert_eq!(pwd_path, std::fs::canonicalize(&nested_dir)?);

        let mut write_cmd = Command::write_file("file.txt", "hello");
        write_cmd.current_dir(Path::new("nested"));
        executor.exec_cmd(&write_cmd).await?;
        // The file must land in the nested dir, not the executor's workdir root
        assert!(!base_path.join("file.txt").exists());
        assert!(nested_dir.join("file.txt").exists());

        let mut read_cmd = Command::read_file("file.txt");
        read_cmd.current_dir(Path::new("nested"));
        let read_output = executor.exec_cmd(&read_cmd).await?.to_string();
        assert_eq!(read_output.trim(), "hello");

        Ok(())
    }

    /// Checks that `scoped` returns a wrapper that redirects relative paths
    /// into the scope while leaving the original executor untouched.
    #[tokio::test]
    async fn test_local_executor_current_dir() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let base_path = temp_dir.path();
        let executor = LocalExecutor {
            workdir: base_path.to_path_buf(),
            ..Default::default()
        };

        let nested = executor.scoped("nested");
        nested
            .exec_cmd(&Command::write_file("file.txt", "hello"))
            .await?;
        assert!(!base_path.join("file.txt").exists());
        assert!(base_path.join("nested").join("file.txt").exists());
        // Scoping must not mutate the wrapped executor's own workdir
        assert_eq!(executor.workdir, base_path);

        Ok(())
    }

    /// Same as above, but through a type-erased executor to prove `scoped`
    /// works on trait objects.
    #[tokio::test]
    async fn test_local_executor_current_dir_dyn() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let base_path = temp_dir.path();
        let executor = LocalExecutor {
workdir: base_path.to_path_buf(), ..Default::default() }; let dyn_exec: Arc = Arc::new(executor.clone()); let nested = dyn_exec.scoped("nested"); nested .exec_cmd(&Command::write_file("nested_file.txt", "hello")) .await?; assert!(base_path.join("nested").join("nested_file.txt").exists()); assert!(!base_path.join("nested_file.txt").exists()); Ok(()) } } ================================================ FILE: swiftide-agents/src/tools/mcp.rs ================================================ //! Add tools provided by an MCP server to an agent //! //! Uses the `rmcp` crate to connect to an MCP server and list available tools, and invoke them //! //! Supports any transport that the `rmcp` crate supports use std::borrow::Cow; use std::sync::Arc; use anyhow::{Context as _, Result}; use async_trait::async_trait; use rmcp::RoleClient; use rmcp::ServiceExt; use rmcp::model::{CallToolRequestParams, ClientInfo, Implementation, InitializeRequestParams}; use rmcp::service::RunningService; use rmcp::transport::IntoTransport; use schemars::Schema; use serde::{Deserialize, Serialize}; use swiftide_core::CommandError; use swiftide_core::chat_completion::ToolCall; use swiftide_core::{ Tool, ToolBox, chat_completion::{ToolSpec, errors::ToolError}, }; use tokio::sync::RwLock; /// A filter to apply to the available tools #[derive(Clone, Debug, Serialize, Deserialize)] pub enum ToolFilter { Blacklist(Vec), Whitelist(Vec), } /// Connects to an MCP server and provides tools at runtime to the agent. /// /// WARN: The rmcp has a quirky feature to serve from `()`. This does not work; serve from /// `ClientInfo` instead, or from the transport and `Swiftide` will handle the rest. 
#[derive(Clone)] pub struct McpToolbox { service: Arc>>>, /// Optional human readable name for the toolbox name: Option, filter: Arc>, } impl McpToolbox { /// Blacklist tools by name, the agent will not be able to use these tools pub fn with_blacklist, I: IntoIterator>( &mut self, blacklist: I, ) -> &mut Self { let list = blacklist.into_iter().map(Into::into).collect::>(); self.filter = Some(ToolFilter::Blacklist(list)).into(); self } /// Whitelist tools by name, the agent will only be able to use these tools pub fn with_whitelist, I: IntoIterator>( &mut self, blacklist: I, ) -> &mut Self { let list = blacklist.into_iter().map(Into::into).collect::>(); self.filter = Some(ToolFilter::Whitelist(list)).into(); self } /// Apply a custom filter to the tools pub fn with_filter(&mut self, filter: ToolFilter) -> &mut Self { self.filter = Some(filter).into(); self } /// Apply an optional name to the toolbox pub fn with_name(&mut self, name: impl Into) -> &mut Self { self.name = Some(name.into()); self } pub fn name(&self) -> &str { self.name.as_deref().unwrap_or("MCP Toolbox") } /// Create a new toolbox from a transport /// /// # Errors /// /// Errors if the transport fails to connect pub async fn try_from_transport< E: std::error::Error + From + Send + Sync + 'static, A, >( transport: impl IntoTransport, ) -> Result { let info = Self::default_client_info(); let service = Arc::new(RwLock::new(Some(info.serve(transport).await?))); Ok(Self { service, filter: None.into(), name: None, }) } /// Create a new toolbox from a running service pub fn from_running_service( service: RunningService, ) -> Self { Self { service: Arc::new(RwLock::new(Some(service))), filter: None.into(), name: None, } } fn default_client_info() -> ClientInfo { ClientInfo { client_info: Implementation { name: "swiftide".into(), version: env!("CARGO_PKG_VERSION").into(), title: None, description: None, icons: None, website_url: None, }, ..Default::default() } } /// Disconnects from the MCP server if it is 
running /// /// If it is not running, an Ok is returned and it logs a tracing message /// /// # Errors /// /// Errors if the service is running but cannot be stopped pub async fn cancel(&mut self) -> Result<()> { let mut lock = self.service.write().await; let Some(service) = std::mem::take(&mut *lock) else { tracing::warn!("mcp server is not running"); return Ok(()); }; tracing::debug!(name = self.name(), "Stopping mcp server"); service .cancel() .await .context("failed to stop mcp server")?; Ok(()) } } #[async_trait] impl ToolBox for McpToolbox { #[tracing::instrument(skip_all)] async fn available_tools(&self) -> Result>> { let Some(service) = &*self.service.read().await else { anyhow::bail!("No service available"); }; tracing::debug!(name = self.name(), "Connecting to mcp server"); let peer_info = service.peer_info(); tracing::debug!(?peer_info, name = self.name(), "Connected to mcp server"); tracing::debug!(name = self.name(), "Listing tools from mcp server"); let tools = service .list_all_tools() .await .context("Failed to list tools")?; let filter = self.filter.as_ref(); let mut server_name = peer_info .map_or("mcp", |info| info.server_info.name.as_str()) .trim() .to_owned(); if server_name.is_empty() { server_name = "mcp".into(); } let tools = tools .into_iter() .filter(|tool| match &filter { Some(ToolFilter::Blacklist(blacklist)) => { !blacklist.iter().any(|blocked| blocked == &tool.name) } Some(ToolFilter::Whitelist(whitelist)) => { whitelist.iter().any(|allowed| allowed == &tool.name) } None => true, }) .map(|tool| { let schema_value = tool.schema_as_json_value(); tracing::trace!( schema = ?schema_value, "Parsing tool input schema for {}", tool.name ); let mut tool_spec_builder = ToolSpec::builder(); // Preallocate to avoid repeated string growth. 
let mut registered_name = String::with_capacity(server_name.len() + tool.name.len() + 1); registered_name.push_str(&server_name); registered_name.push(':'); registered_name.push_str(&tool.name); tool_spec_builder.name(registered_name.clone()); if let Some(description) = tool.description { tool_spec_builder.description(description); } match schema_value { serde_json::Value::Null => {} value => { let schema: Schema = serde_json::from_value(value) .context("Failed to parse tool input schema")?; tool_spec_builder.parameters_schema(schema); } } let tool_spec = tool_spec_builder .build() .context("Failed to build tool spec")?; Ok(Box::new(McpTool { client: Arc::clone(&self.service), registered_name, server_tool_name: tool.name.into(), tool_spec, }) as Box) }) .collect::>>() .context("Failed to build mcp tool specs")?; Ok(tools) } fn name(&self) -> Cow<'_, str> { self.name().into() } } #[derive(Clone)] struct McpTool { client: Arc>>>, registered_name: String, server_tool_name: String, tool_spec: ToolSpec, } #[async_trait] impl Tool for McpTool { async fn invoke( &self, _agent_context: &dyn swiftide_core::AgentContext, tool_call: &ToolCall, ) -> Result< swiftide_core::chat_completion::ToolOutput, swiftide_core::chat_completion::errors::ToolError, > { let args = match tool_call.args() { Some(args) => Some(serde_json::from_str(args).map_err(ToolError::WrongArguments)?), None => None, }; let request = CallToolRequestParams { meta: None, name: self.server_tool_name.clone().into(), arguments: args, task: None, }; let Some(service) = &*self.client.read().await else { return Err( CommandError::ExecutorError(anyhow::anyhow!("mcp server is not running")).into(), ); }; tracing::debug!(request = ?request, tool = self.name().as_ref(), "Invoking mcp tool"); let response = service .call_tool(request) .await .context("Failed to call tool")?; tracing::debug!(response = ?response, tool = self.name().as_ref(), "Received response from mcp tool"); let rmcp::model::CallToolResult { content, 
structured_content, is_error, .. } = response; let content = if content.is_empty() { structured_content.map(|structured| structured.to_string()) } else { let mut iter = content.into_iter().filter_map(|c| match c.raw { rmcp::model::RawContent::Text(rmcp::model::RawTextContent { text, .. }) => { Some(text) } _ => None, }); iter.next().map(|first| { let mut joined = first; for part in iter { joined.push('\n'); joined.push_str(&part); } joined }) }; if is_error.unwrap_or(false) { let content = content.unwrap_or_else(|| "Unknown error".to_string()); return Err(ToolError::Unknown(anyhow::anyhow!( "Failed to execute mcp tool: {content}" ))); } match content { Some(content) => Ok(content.into()), // Some MCP tools may legitimately return no textual or structured content // while still being successful (e.g. optional echo with null input). None => Ok("Tool executed successfully".into()), } } fn name(&self) -> std::borrow::Cow<'_, str> { self.registered_name.as_str().into() } fn tool_spec(&self) -> ToolSpec { self.tool_spec.clone() } } #[cfg(test)] mod tests { use super::*; use copied_from_rmcp::Calculator; use rmcp::serve_server; use tokio::net::{UnixListener, UnixStream}; const SOCKET_PATH: &str = "/tmp/swiftide-mcp.sock"; const EXPECTED_PREFIX: &str = "rmcp"; #[allow(clippy::similar_names)] #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn test_socket() { let _ = std::fs::remove_file(SOCKET_PATH); match UnixListener::bind(SOCKET_PATH) { Ok(unix_listener) => { println!("Server successfully listening on {SOCKET_PATH}"); tokio::spawn(server(unix_listener)); } Err(e) => { println!("Unable to bind to {SOCKET_PATH}: {e}"); } } let client = client().await.unwrap(); let t = client.available_tools().await.unwrap(); assert_eq!(client.available_tools().await.unwrap().len(), 3); let mut names = t.iter().map(|t| t.name().into_owned()).collect::>(); names.sort(); assert_eq!( names, [ format!("{EXPECTED_PREFIX}:optional"), format!("{EXPECTED_PREFIX}:sub"), 
format!("{EXPECTED_PREFIX}:sum") ] ); let sum_name = format!("{EXPECTED_PREFIX}:sum"); let sum_tool = t.iter().find(|t| t.name().as_ref() == sum_name).unwrap(); let mut builder = ToolCall::builder() .id("some") .args(r#"{"b": "hello"}"#) .name("test") .name("test") .to_owned(); assert_eq!(sum_tool.tool_spec().name, sum_name); let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap(); let result = sum_tool .invoke(&(), &tool_call) .await .unwrap() .content() .unwrap() .to_string(); assert_eq!(result, "30"); let sub_name = format!("{EXPECTED_PREFIX}:sub"); let sub_tool = t.iter().find(|t| t.name().as_ref() == sub_name).unwrap(); assert_eq!(sub_tool.tool_spec().name, sub_name); let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap(); let result = sub_tool .invoke(&(), &tool_call) .await .unwrap() .content() .unwrap() .to_string(); assert_eq!(result, "-10"); // The input schema type for the input param is string with null allowed let optional_name = format!("{EXPECTED_PREFIX}:optional"); let optional_tool = t .iter() .find(|t| t.name().as_ref() == optional_name) .unwrap(); assert_eq!(optional_tool.tool_spec().name, optional_name); let spec = optional_tool.tool_spec(); let schema = spec .parameters_schema .expect("optional tool should expose a schema"); let schema_json = serde_json::to_value(schema).unwrap(); let _text_prop = schema_json .get("properties") .and_then(|props| props.get("text")) .expect("optional tool schema must include `text`"); let tool_call = builder.args(r#"{"text": "hello"}"#).build().unwrap(); let result = optional_tool .invoke(&(), &tool_call) .await .unwrap() .content() .unwrap() .to_string(); assert_eq!(result, "hello"); let tool_call = builder.args(r#"{"text": null}"#).build().unwrap(); let result = optional_tool .invoke(&(), &tool_call) .await .unwrap() .content() .unwrap() .to_string(); assert_eq!(result, ""); // Clean up socket file let _ = std::fs::remove_file(SOCKET_PATH); } async fn server(unix_listener: 
UnixListener) -> anyhow::Result<()> { while let Ok((stream, addr)) = unix_listener.accept().await { println!("Client connected: {addr:?}"); tokio::spawn(async move { match serve_server(Calculator::new(), stream).await { Ok(server) => { println!("Server initialized successfully"); if let Err(e) = server.waiting().await { println!("Error while server waiting: {e:?}"); } } Err(e) => println!("Server initialization failed: {e:?}"), } anyhow::Ok(()) }); } Ok(()) } async fn client() -> anyhow::Result { println!("Client connecting to {SOCKET_PATH}"); let stream = UnixStream::connect(SOCKET_PATH).await?; // let client = serve_client((), stream).await?; let client = McpToolbox::try_from_transport(stream).await?; println!("Client connected and initialized successfully"); Ok(client) } #[allow(clippy::unused_self)] mod copied_from_rmcp { use rmcp::{ ErrorData as McpError, ServerHandler, handler::server::{tool::ToolRouter, wrapper::Parameters}, model::{CallToolResult, Content, ServerCapabilities, ServerInfo}, schemars, tool, tool_handler, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct Request { pub a: i32, pub b: i32, } #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct OptRequest { pub text: Option, } #[derive(Debug, Clone)] pub struct Calculator { tool_router: ToolRouter, } #[rmcp::tool_router] impl Calculator { pub fn new() -> Self { Self { tool_router: Self::tool_router(), } } #[allow(clippy::unnecessary_wraps)] #[tool(description = "Calculate the sum of two numbers")] fn sum( &self, Parameters(Request { a, b }): Parameters, ) -> Result { Ok(CallToolResult::success(vec![Content::text( (a + b).to_string(), )])) } #[allow(clippy::unnecessary_wraps)] #[tool(description = "Calculate the sum of two numbers")] fn sub( &self, Parameters(Request { a, b }): Parameters, ) -> Result { Ok(CallToolResult::success(vec![Content::text( (a - b).to_string(), )])) } #[allow(clippy::unnecessary_wraps)] #[tool(description = "Optional echo")] fn 
optional( &self, Parameters(OptRequest { text }): Parameters, ) -> Result { Ok(CallToolResult::success(vec![Content::text( text.unwrap_or_default(), )])) } } #[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { instructions: Some("A simple calculator".into()), capabilities: ServerCapabilities::builder().enable_tools().build(), ..Default::default() } } } } } ================================================ FILE: swiftide-agents/src/tools/mod.rs ================================================ //! Default tools and executor for agents pub mod arg_preprocessor; pub mod control; pub mod local_executor; /// Add tools from a Model Context Protocol endpoint #[cfg(feature = "mcp")] pub mod mcp; ================================================ FILE: swiftide-agents/src/util.rs ================================================ //! Internal utility functions and macros for anything agent /// Simple macro to consistently call hooks and clean up the code #[macro_export] macro_rules! 
invoke_hooks {
    // Streaming variant: invoked once per stream chunk, so we log less and
    // only at the trace level to avoid flooding the logs.
    (OnStream, $self_expr:expr $(, $arg:expr)* ) => {{
        // For streaming we log less and only on the trace level
        for hook in $self_expr.hooks_by_type(HookTypes::OnStream) {
            // Downcast to the correct closure variant
            if let Hook::OnStream(hook_fn) = hook {
                // Create a tracing span for instrumentation
                let span = tracing::trace_span!(
                    "hook",
                    "otel.name" = format!("hook.{:?}", HookTypes::OnStream)
                );
                // Call the hook, instrument, and log on failure
                // (hook errors are logged, never propagated)
                if let Err(err) = hook_fn($self_expr $(, $arg)*)
                    .instrument(span.or_current())
                    .await
                {
                    tracing::error!(
                        "Error in {hooktype} hook: {err}",
                        hooktype = HookTypes::OnStream,
                    );
                }
            }
        }
    }};
    // Generic variant for every other hook type; logs at debug/info level.
    ($hook_type:ident, $self_expr:expr $(, $arg:expr)* ) => {{
        // Iterate through every hook matching `HookTypes::$hook_type`
        for hook in $self_expr.hooks_by_type(HookTypes::$hook_type) {
            // Downcast to the correct closure variant
            if let Hook::$hook_type(hook_fn) = hook {
                // Create a tracing span for instrumentation
                let span = tracing::info_span!(
                    "hook",
                    "otel.name" = format!("hook.{:?}", HookTypes::$hook_type)
                );
                tracing::debug!("Calling {} hook", HookTypes::$hook_type);
                // Call the hook, instrument, and log on failure
                // (hook errors are logged, never propagated)
                if let Err(err) = hook_fn($self_expr $(, $arg)*)
                    .instrument(span.or_current())
                    .await
                {
                    tracing::error!(
                        "Error in {hooktype} hook: {err}",
                        hooktype = HookTypes::$hook_type,
                    );
                }
            }
        }
    }};
}

================================================
FILE: swiftide-core/Cargo.toml
================================================
cargo-features = ["edition2024"]

[package]
name = "swiftide-core"
version.workspace = true
edition.workspace = true
license.workspace = true
readme.workspace = true
keywords.workspace = true
description.workspace = true
categories.workspace = true
repository.workspace = true
homepage.workspace = true

[dependencies]
anyhow = { workspace = true }
tokio = { workspace = true, features = ["full"] }
tracing = { workspace = true }
async-trait = { workspace = true }
futures-util = { workspace = true }
tokio-stream = { workspace = true }
itertools = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
strum = { workspace = true }
strum_macros = { workspace = true }
mockall = { workspace = true, optional = true }
lazy_static = { workspace = true }
derive_builder = { workspace = true }
dyn-clone = { workspace = true }
pin-project = { workspace = true }
thiserror = { workspace = true }
metrics = { workspace = true, optional = true }
schemars = { workspace = true, features = ["derive"] }
async-openai = { workspace = true, optional = true, features = ["chat-completion-types", "embedding-types", "response-types"] }
tera = { workspace = true }
uuid = { workspace = true, features = ["v4", "v3"] }
pretty_assertions = { workspace = true, optional = true }

# Integrations
qdrant-client = { workspace = true, optional = true }
backoff = { version = "0.4.0", features = ["futures", "tokio"] }

[dev-dependencies]
test-case = { workspace = true }
test-log = { workspace = true }
tokio = { workspace = true, features = ["time", "test-util"] }
tokio-stream = { workspace = true }

[features]
defaults = ["truncate-debug"]
test-utils = ["dep:mockall", "dep:pretty_assertions"]
qdrant = ["dep:qdrant-client"]
# Truncates large debug outputs on pipeline nodes
truncate-debug = []
metrics = ["dep:metrics"]
json-schema = []
openai = ["dep:async-openai"]

[lints]
workspace = true

[package.metadata.docs.rs]
all-features = true
cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"]
rustdoc-args = ["--cfg", "docsrs"]

================================================
FILE: swiftide-core/README.md
================================================
# Swiftide-core

Core crate includes low level types and traits for swiftide that are used by other crates.
================================================ FILE: swiftide-core/src/agent_traits.rs ================================================ use std::{ borrow::Cow, path::{Path, PathBuf}, sync::{Arc, Mutex}, time::Duration, }; use crate::{ chat_completion::{ChatMessage, ToolCall}, indexing::IndexingStream, }; use anyhow::Result; use async_trait::async_trait; use dyn_clone::DynClone; use serde::{Deserialize, Serialize}; use thiserror::Error; /// A `ToolExecutor` provides an interface for agents to interact with a system /// in an isolated context. /// /// When starting up an agent, it's context expects an executor. For example, /// you might want your coding agent to work with a fresh, isolated set of files, /// separated from the rest of the system. /// /// See `swiftide-docker-executor` for an executor that uses Docker. By default /// the executor is a local executor. /// /// Additionally, the executor can be used stream files files for indexing. #[async_trait] pub trait ToolExecutor: Send + Sync + DynClone { /// Execute a command in the executor async fn exec_cmd(&self, cmd: &Command) -> Result; /// Stream files from the executor async fn stream_files( &self, path: &Path, extensions: Option>, ) -> Result>; } dyn_clone::clone_trait_object!(ToolExecutor); /// Lightweight executor wrapper that applies a default working directory to forwarded commands. /// /// Most callers should construct this via [`ExecutorExt::scoped`], which borrows the underlying /// executor and only clones commands/paths when the scope actually changes their resolution. #[derive(Debug, Clone)] pub struct ScopedExecutor { executor: E, scope: PathBuf, } impl ScopedExecutor { /// Build a new wrapper around `executor` that prefixes relative paths with `scope`. pub fn new(executor: E, scope: impl Into) -> Self { Self { executor, scope: scope.into(), } } /// Returns either the original command or a scoped clone depending on the current directory. 
fn apply_scope<'a>(&'a self, cmd: &'a Command) -> Cow<'a, Command> { match cmd.current_dir_path() { Some(path) if path.is_absolute() || self.scope.as_os_str().is_empty() => { Cow::Borrowed(cmd) } Some(path) => { let mut scoped = cmd.clone(); scoped.current_dir(self.scope.join(path)); Cow::Owned(scoped) } None if self.scope.as_os_str().is_empty() => Cow::Borrowed(cmd), None => { let mut scoped = cmd.clone(); scoped.current_dir(self.scope.clone()); Cow::Owned(scoped) } } } /// Returns a path adjusted for the scope when the provided path is relative. fn scoped_path<'a>(&'a self, path: &'a Path) -> Cow<'a, Path> { if path.is_absolute() || self.scope.as_os_str().is_empty() { Cow::Borrowed(path) } else { Cow::Owned(self.scope.join(path)) } } /// Access the inner executor. pub fn inner(&self) -> &E { &self.executor } /// Expose the scope that will be applied to relative paths. pub fn scope(&self) -> &Path { &self.scope } } #[async_trait] impl ToolExecutor for ScopedExecutor where E: ToolExecutor + Send + Sync + Clone, { async fn exec_cmd(&self, cmd: &Command) -> Result { let scoped_cmd = self.apply_scope(cmd); self.executor.exec_cmd(scoped_cmd.as_ref()).await } async fn stream_files( &self, path: &Path, extensions: Option>, ) -> Result> { let scoped_path = self.scoped_path(path); self.executor .stream_files(scoped_path.as_ref(), extensions) .await } } /// Convenience methods for scoping executors without cloning them. pub trait ExecutorExt { /// Borrow `self` and return a wrapper that resolves relative operations inside `path`. 
fn scoped(&self, path: impl Into) -> ScopedExecutor<&Self>; fn scoped_owned(self, path: impl Into) -> ScopedExecutor where Self: Sized; } impl ExecutorExt for T where T: ToolExecutor, { fn scoped(&self, path: impl Into) -> ScopedExecutor<&Self> { ScopedExecutor::new(self, path) } fn scoped_owned(self, path: impl Into) -> ScopedExecutor { ScopedExecutor::new(self, path) } } #[async_trait] impl ToolExecutor for &T where T: ToolExecutor + ?Sized, { async fn exec_cmd(&self, cmd: &Command) -> Result { (**self).exec_cmd(cmd).await } async fn stream_files( &self, path: &Path, extensions: Option>, ) -> Result> { (**self).stream_files(path, extensions).await } } #[async_trait] impl ToolExecutor for Arc { async fn exec_cmd(&self, cmd: &Command) -> Result { self.as_ref().exec_cmd(cmd).await } async fn stream_files( &self, path: &Path, extensions: Option>, ) -> Result> { self.as_ref().stream_files(path, extensions).await } } #[async_trait] impl ToolExecutor for Box { async fn exec_cmd(&self, cmd: &Command) -> Result { self.as_ref().exec_cmd(cmd).await } async fn stream_files( &self, path: &Path, extensions: Option>, ) -> Result> { self.as_ref().stream_files(path, extensions).await } } #[derive(Debug, Error)] pub enum CommandError { /// The executor itself failed #[error("executor error: {0:#}")] ExecutorError(#[from] anyhow::Error), /// The command exceeded its allotted time budget #[error("command timed out after {timeout:?}: {output}")] TimedOut { timeout: Duration, output: CommandOutput, }, /// The command failed, i.e. failing tests with stderr. This error might be handled #[error("command failed with NonZeroExit: {0}")] NonZeroExit(CommandOutput), } impl From for CommandError { fn from(err: std::io::Error) -> Self { CommandError::NonZeroExit(err.to_string().into()) } } /// Commands that can be executed by the executor /// Conceptually, `Shell` allows any kind of input, and other commands enable more optimized /// implementations. 
/// Commands that can be executed by the executor
/// Conceptually, `Shell` allows any kind of input, and other commands enable more optimized
/// implementations.
///
/// There is an ongoing consideration to make this an associated type on the executor
///
/// TODO: Should be able to borrow everything?
///
/// Use the constructor helpers (e.g. [`Command::shell`]) and then chain configuration methods
/// such as [`Command::with_current_dir`] or [`Command::current_dir`] for builder-style ergonomics.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Command {
    Shell {
        command: String,
        current_dir: Option<PathBuf>,
        timeout: Option<Duration>,
    },
    ReadFile {
        path: PathBuf,
        current_dir: Option<PathBuf>,
        timeout: Option<Duration>,
    },
    WriteFile {
        path: PathBuf,
        content: String,
        current_dir: Option<PathBuf>,
        timeout: Option<Duration>,
    },
}

impl Command {
    /// Create a shell command from anything string-like.
    pub fn shell<S: Into<String>>(cmd: S) -> Self {
        Command::Shell {
            command: cmd.into(),
            current_dir: None,
            timeout: None,
        }
    }

    /// Create a command that reads the file at `path`.
    pub fn read_file<P: Into<PathBuf>>(path: P) -> Self {
        Command::ReadFile {
            path: path.into(),
            current_dir: None,
            timeout: None,
        }
    }

    /// Create a command that writes `content` to the file at `path`.
    pub fn write_file<P: Into<PathBuf>, S: Into<String>>(path: P, content: S) -> Self {
        Command::WriteFile {
            path: path.into(),
            content: content.into(),
            current_dir: None,
            timeout: None,
        }
    }

    /// Override the working directory used when executing this command.
    ///
    /// Executors may interpret relative paths in the context of their own
    /// working directory.
    #[must_use]
    pub fn with_current_dir<P: Into<PathBuf>>(mut self, path: P) -> Self {
        self.current_dir(path);
        self
    }

    /// Override the working directory using the `std::process::Command`
    /// builder-lite style API.
    pub fn current_dir<P: Into<PathBuf>>(&mut self, path: P) -> &mut Self {
        let dir = Some(path.into());
        // All variants carry a `current_dir` slot; bind them with one arm
        match self {
            Command::Shell { current_dir, .. }
            | Command::ReadFile { current_dir, .. }
            | Command::WriteFile { current_dir, .. } => {
                *current_dir = dir;
            }
        }
        self
    }

    /// Remove any working-directory override previously configured.
    pub fn clear_current_dir(&mut self) -> &mut Self {
        match self {
            Command::Shell { current_dir, .. }
            | Command::ReadFile { current_dir, .. }
            | Command::WriteFile { current_dir, .. } => {
                *current_dir = None;
            }
        }
        self
    }

    /// Returns the working-directory override for this command, if any.
    pub fn current_dir_path(&self) -> Option<&Path> {
        match self {
            Command::Shell { current_dir, .. }
            | Command::ReadFile { current_dir, .. }
            | Command::WriteFile { current_dir, .. } => current_dir.as_deref(),
        }
    }

    /// Override the timeout used when executing this command.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout(timeout);
        self
    }

    /// Override the timeout using the builder-style API.
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        match self {
            Command::Shell { timeout: slot, .. }
            | Command::ReadFile { timeout: slot, .. }
            | Command::WriteFile { timeout: slot, .. } => {
                *slot = Some(timeout);
            }
        }
        self
    }

    /// Remove any timeout previously configured on this command.
    pub fn clear_timeout(&mut self) -> &mut Self {
        match self {
            Command::Shell { timeout, .. }
            | Command::ReadFile { timeout, .. }
            | Command::WriteFile { timeout, .. } => {
                *timeout = None;
            }
        }
        self
    }

    /// Returns the timeout associated with this command, if any.
    pub fn timeout_duration(&self) -> Option<&Duration> {
        match self {
            Command::Shell { timeout, .. }
            | Command::ReadFile { timeout, .. }
            | Command::WriteFile { timeout, .. } => timeout.as_ref(),
        }
    }
}

/// Output from a `Command`
#[derive(Debug, Clone)]
pub struct CommandOutput {
    pub output: String,
    // status_code: i32,
    // success: bool,
}

impl CommandOutput {
    /// An output with no content.
    pub fn empty() -> Self {
        CommandOutput {
            output: String::new(),
        }
    }

    /// Wrap anything string-like as command output.
    pub fn new(output: impl Into<String>) -> Self {
        CommandOutput {
            output: output.into(),
        }
    }

    /// True when the command produced no output.
    pub fn is_empty(&self) -> bool {
        self.output.is_empty()
    }
}

impl std::fmt::Display for CommandOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.output.fmt(f)
    }
}

impl<T: Into<String>> From<T> for CommandOutput {
    fn from(value: T) -> Self {
        CommandOutput {
            output: value.into(),
        }
    }
}

impl AsRef<str> for CommandOutput {
    fn as_ref(&self) -> &str {
        &self.output
    }
}
with a human in the loop #[derive(Debug, Clone, Serialize, Deserialize, strum_macros::EnumIs)] #[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))] pub enum ToolFeedback { Approved { payload: Option }, Refused { payload: Option }, } impl ToolFeedback { pub fn approved() -> Self { ToolFeedback::Approved { payload: None } } pub fn refused() -> Self { ToolFeedback::Refused { payload: None } } pub fn payload(&self) -> Option<&serde_json::Value> { match self { ToolFeedback::Refused { payload } | ToolFeedback::Approved { payload } => { payload.as_ref() } } } #[must_use] pub fn with_payload(self, payload: serde_json::Value) -> Self { match self { ToolFeedback::Approved { .. } => ToolFeedback::Approved { payload: Some(payload), }, ToolFeedback::Refused { .. } => ToolFeedback::Refused { payload: Some(payload), }, } } } /// Acts as the interface to the external world and manages messages for completion #[async_trait] pub trait AgentContext: Send + Sync { /// List of all messages for this agent /// /// Used as main source for the next completion and expects all /// messages to be returned if new messages are present. /// /// Once this method has been called, there should not be new messages /// /// TODO: Figure out a nice way to return a reference instead while still supporting i.e. 
/// mutexes async fn next_completion(&self) -> Result>>; /// Lists only the new messages after calling `new_completion` async fn current_new_messages(&self) -> Result>; /// Add messages for the next completion async fn add_messages(&self, item: Vec) -> Result<()>; /// Add messages for the next completion async fn add_message(&self, item: ChatMessage) -> Result<()>; /// Execute a command if the context supports it /// /// Deprecated: use executor instead to access the executor directly #[deprecated(note = "use executor instead")] async fn exec_cmd(&self, cmd: &Command) -> Result; fn executor(&self) -> &Arc; async fn history(&self) -> Result>; /// Replace the entire history with the given items async fn replace_history(&self, items: Vec) -> Result<()>; /// Pops the last messages up until the last completion /// /// LLMs failing completion for various reasons is unfortunately a common occurrence /// This gives a way to redrive the last completion in a generic way async fn redrive(&self) -> Result<()>; /// Tools that require feedback or approval (i.e. 
from a human) can use this to check if the /// feedback is received async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option; async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()>; } #[async_trait] impl AgentContext for Box { async fn next_completion(&self) -> Result>> { (**self).next_completion().await } async fn current_new_messages(&self) -> Result> { (**self).current_new_messages().await } async fn add_messages(&self, item: Vec) -> Result<()> { (**self).add_messages(item).await } async fn add_message(&self, item: ChatMessage) -> Result<()> { (**self).add_message(item).await } #[allow(deprecated)] async fn exec_cmd(&self, cmd: &Command) -> Result { (**self).exec_cmd(cmd).await } fn executor(&self) -> &Arc { (**self).executor() } async fn history(&self) -> Result> { (**self).history().await } async fn replace_history(&self, items: Vec) -> Result<()> { (**self).replace_history(items).await } async fn redrive(&self) -> Result<()> { (**self).redrive().await } async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option { (**self).has_received_feedback(tool_call).await } async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> { (**self).feedback_received(tool_call, feedback).await } } #[async_trait] impl AgentContext for Arc { async fn next_completion(&self) -> Result>> { (**self).next_completion().await } async fn current_new_messages(&self) -> Result> { (**self).current_new_messages().await } async fn add_messages(&self, item: Vec) -> Result<()> { (**self).add_messages(item).await } async fn add_message(&self, item: ChatMessage) -> Result<()> { (**self).add_message(item).await } #[allow(deprecated)] async fn exec_cmd(&self, cmd: &Command) -> Result { (**self).exec_cmd(cmd).await } fn executor(&self) -> &Arc { (**self).executor() } async fn history(&self) -> Result> { (**self).history().await } async fn replace_history(&self, items: Vec) -> Result<()> { 
(**self).replace_history(items).await } async fn redrive(&self) -> Result<()> { (**self).redrive().await } async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option { (**self).has_received_feedback(tool_call).await } async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> { (**self).feedback_received(tool_call, feedback).await } } #[async_trait] impl AgentContext for &dyn AgentContext { async fn next_completion(&self) -> Result>> { (**self).next_completion().await } async fn current_new_messages(&self) -> Result> { (**self).current_new_messages().await } async fn add_messages(&self, item: Vec) -> Result<()> { (**self).add_messages(item).await } async fn add_message(&self, item: ChatMessage) -> Result<()> { (**self).add_message(item).await } #[allow(deprecated)] async fn exec_cmd(&self, cmd: &Command) -> Result { (**self).exec_cmd(cmd).await } fn executor(&self) -> &Arc { (**self).executor() } async fn history(&self) -> Result> { (**self).history().await } async fn replace_history(&self, items: Vec) -> Result<()> { (**self).replace_history(items).await } async fn redrive(&self) -> Result<()> { (**self).redrive().await } async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option { (**self).has_received_feedback(tool_call).await } async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> { (**self).feedback_received(tool_call, feedback).await } } /// Convenience implementation for empty agent context /// /// Errors if tools attempt to execute commands #[async_trait] impl AgentContext for () { async fn next_completion(&self) -> Result>> { Ok(None) } async fn current_new_messages(&self) -> Result> { Ok(Vec::new()) } async fn add_messages(&self, _item: Vec) -> Result<()> { Ok(()) } async fn add_message(&self, _item: ChatMessage) -> Result<()> { Ok(()) } async fn exec_cmd(&self, _cmd: &Command) -> Result { Err(CommandError::ExecutorError(anyhow::anyhow!( "Empty agent context 
does not have a tool executor" ))) } fn executor(&self) -> &Arc { unimplemented!("Empty agent context does not have a tool executor") } async fn history(&self) -> Result> { Ok(Vec::new()) } async fn replace_history(&self, _items: Vec) -> Result<()> { Ok(()) } async fn redrive(&self) -> Result<()> { Ok(()) } async fn has_received_feedback(&self, _tool_call: &ToolCall) -> Option { Some(ToolFeedback::Approved { payload: None }) } async fn feedback_received( &self, _tool_call: &ToolCall, _feedback: &ToolFeedback, ) -> Result<()> { Ok(()) } } /// A backend for the agent context. /// /// A default is provided for `Arc>>`. /// /// If you want to use for instance a database, implement this trait and pass it to the agent /// context when creating it. #[async_trait] pub trait MessageHistory: Send + Sync + std::fmt::Debug { /// Returns the history of messages async fn history(&self) -> Result>; /// Add a message to the history async fn push_owned(&self, item: ChatMessage) -> Result<()>; /// Overwrite the history with the given items async fn overwrite(&self, items: Vec) -> Result<()>; /// Add a message to the history. async fn push(&self, item: &ChatMessage) -> Result<()> { self.push_owned(item.to_owned()).await } /// Extend the history with the given items. 
async fn extend(&self, items: &[ChatMessage]) -> Result<()> { self.extend_owned(items.iter().map(ChatMessage::to_owned).collect()) .await } /// Extend the history with the given items, taking ownership of them async fn extend_owned(&self, items: Vec) -> Result<()> { for item in items { self.push_owned(item).await?; } Ok(()) } } #[async_trait] impl MessageHistory for Mutex> { async fn history(&self) -> Result> { Ok(self.lock().unwrap().clone()) } async fn push_owned(&self, item: ChatMessage) -> Result<()> { self.lock().unwrap().push(item); Ok(()) } async fn overwrite(&self, items: Vec) -> Result<()> { let mut lock = self.lock().unwrap(); *lock = items; Ok(()) } } ================================================ FILE: swiftide-core/src/chat_completion/chat_completion_request.rs ================================================ use std::{borrow::Cow, collections::BTreeSet}; use derive_builder::Builder; use super::{chat_message::ChatMessage, tools::ToolSpec, traits::Tool}; /// A chat completion request represents a series of chat messages and tool interactions that can /// be send to any LLM. #[derive(Builder, Clone, PartialEq, Debug)] #[builder(setter(into, strip_option))] pub struct ChatCompletionRequest<'a> { pub messages: Cow<'a, [ChatMessage]>, #[builder(default, setter(custom))] pub tools_spec: BTreeSet, } impl<'a> ChatCompletionRequest<'a> { pub fn builder() -> ChatCompletionRequestBuilder<'a> { ChatCompletionRequestBuilder::default() } /// Returns the chat messages included in the request. pub fn messages(&self) -> &[ChatMessage] { self.messages.as_ref() } /// Returns the tool specifications currently attached to the request. pub fn tools_spec(&self) -> &BTreeSet { &self.tools_spec } /// Returns an owned request with `'static` data. 
pub fn to_owned(&self) -> ChatCompletionRequest<'static> { ChatCompletionRequest { messages: Cow::Owned(self.messages.iter().map(ChatMessage::to_owned).collect()), tools_spec: self.tools_spec.clone(), } } } impl From> for ChatCompletionRequest<'_> { fn from(messages: Vec) -> Self { ChatCompletionRequest { messages: Cow::Owned(messages), tools_spec: BTreeSet::new(), } } } impl<'a> From<&'a [ChatMessage]> for ChatCompletionRequest<'a> { fn from(messages: &'a [ChatMessage]) -> Self { ChatCompletionRequest { messages: Cow::Borrowed(messages), tools_spec: BTreeSet::new(), } } } impl ChatCompletionRequestBuilder<'_> { #[deprecated(note = "Use `tools` with real Tool instances instead")] pub fn tools_spec(&mut self, tools_spec: I) -> &mut Self where I: IntoIterator, { self.tools_spec = Some(tools_spec.into_iter().collect()); self } /// Adds multiple tools by deriving their specs from the provided instances. pub fn tools(&mut self, tools: I) -> &mut Self where I: IntoIterator, T: Into>, { let specs = tools.into_iter().map(|tool| { let boxed: Box = tool.into(); boxed.tool_spec() }); self.tool_specs(specs) } /// Adds a single tool instance to the request by deriving its spec. pub fn tool(&mut self, tool: T) -> &mut Self where T: Into>, { let boxed: Box = tool.into(); self.tool_specs(std::iter::once(boxed.tool_spec())) } /// Extends the request with additional tool specifications. pub fn tool_specs(&mut self, specs: I) -> &mut Self where I: IntoIterator, { let entry = self.tools_spec.get_or_insert_with(BTreeSet::new); entry.extend(specs); self } /// Adds a single chat message to the request pub fn message(&mut self, message: impl Into) -> &mut Self { let mut messages = self .messages .take() .map(Cow::into_owned) .unwrap_or_default(); messages.push(message.into()); self.messages = Some(Cow::Owned(messages)); self } /// Extends the request with multiple chat messages. 
pub fn messages_iter(&mut self, messages: I) -> &mut Self where I: IntoIterator, { let mut new_messages = self .messages .take() .map(Cow::into_owned) .unwrap_or_default(); new_messages.extend(messages); self.messages = Some(Cow::Owned(new_messages)); self } } #[cfg(test)] mod tests { use super::ChatCompletionRequest; use crate::chat_completion::{ChatMessage, ToolSpec}; use schemars::Schema; use serde_json::json; #[test] fn tool_specs_are_stored_in_deterministic_order() { let zebra = ToolSpec::builder() .name("zebra") .description("later alphabetically") .parameters_schema(schema_from_json(json!({ "type": "object", "properties": { "b": { "type": "string" }, "a": { "type": "string" } } }))) .build() .unwrap(); let alpha = ToolSpec::builder() .name("alpha") .description("earlier alphabetically") .parameters_schema(schema_from_json(json!({ "properties": { "z": { "type": "string" }, "m": { "type": "string" } }, "type": "object" }))) .build() .unwrap(); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .tool_specs([zebra, alpha]) .build() .unwrap(); let names = request .tools_spec() .iter() .map(|spec| spec.name.as_str()) .collect::>(); assert_eq!(names, vec!["alpha", "zebra"]); } fn schema_from_json(value: serde_json::Value) -> Schema { serde_json::from_value(value).expect("valid schema") } } ================================================ FILE: swiftide-core/src/chat_completion/chat_completion_response.rs ================================================ use std::collections::HashMap; use derive_builder::Builder; use serde::{Deserialize, Serialize}; use uuid::Uuid; use super::{ReasoningItem, ToolCallBuilder, tools::ToolCall}; /// A generic response from chat completions /// /// When streaming, the delta is available. Every response will have the accumulated message if /// present. The final message will also have the final tool calls. 
#[derive(Clone, Builder, Debug, Serialize, Deserialize, PartialEq)] #[builder(setter(strip_option, into), build_fn(error = anyhow::Error))] pub struct ChatCompletionResponse { /// An identifier for the response /// /// Useful when streaming to make sure chunks can be mapped to the right response #[builder(private, default = Uuid::new_v4())] pub id: Uuid, #[builder(default)] pub message: Option, #[builder(default)] pub tool_calls: Option>, #[builder(default)] pub usage: Option, #[builder(default)] pub reasoning: Option>, /// Streaming response #[builder(default)] pub delta: Option, } impl Default for ChatCompletionResponse { fn default() -> Self { Self { id: Uuid::new_v4(), message: None, tool_calls: None, delta: None, usage: None, reasoning: None, } } } /// Usage statistics for a language model response. #[derive(Clone, Default, Builder, Debug, Serialize, Deserialize, PartialEq)] #[allow(clippy::struct_field_names)] pub struct Usage { /// Tokens used in the prompt or input. pub prompt_tokens: u32, /// Tokens generated in the completion or output. pub completion_tokens: u32, /// Total tokens used for the request. pub total_tokens: u32, /// Provider-specific usage breakdowns, when available. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] pub details: Option, } impl Usage { pub fn builder() -> UsageBuilder { UsageBuilder::default() } /// Returns a normalized view of usage details when available. /// /// This keeps the public `Usage` fields intact and derives a consistent input/output breakdown /// across providers (e.g. `OpenAI` chat vs. responses). Missing data is left as `None`. 
pub fn normalized(&self) -> NormalizedUsage { let details = self.details.as_ref().map(|details| { let input = NormalizedInputUsageDetails { cached_tokens: details .input_tokens_details .as_ref() .and_then(|input| input.cached_tokens) .or_else(|| { details .prompt_tokens_details .as_ref() .and_then(|prompt| prompt.cached_tokens) }), audio_tokens: details .prompt_tokens_details .as_ref() .and_then(|prompt| prompt.audio_tokens), }; let output = NormalizedOutputUsageDetails { reasoning_tokens: details .output_tokens_details .as_ref() .and_then(|output| output.reasoning_tokens) .or_else(|| { details .completion_tokens_details .as_ref() .and_then(|completion| completion.reasoning_tokens) }), audio_tokens: details .completion_tokens_details .as_ref() .and_then(|completion| completion.audio_tokens), accepted_prediction_tokens: details .completion_tokens_details .as_ref() .and_then(|completion| completion.accepted_prediction_tokens), rejected_prediction_tokens: details .completion_tokens_details .as_ref() .and_then(|completion| completion.rejected_prediction_tokens), }; if input.is_empty() && output.is_empty() { None } else { Some(NormalizedUsageDetails { input, output }) } }); NormalizedUsage { prompt_tokens: self.prompt_tokens, completion_tokens: self.completion_tokens, total_tokens: self.total_tokens, details: details.flatten(), } } } /// Provider-specific usage breakdowns for a response. #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct UsageDetails { /// Chat-completions style prompt token details. #[serde(skip_serializing_if = "Option::is_none")] pub prompt_tokens_details: Option, /// Chat-completions style completion token details. #[serde(skip_serializing_if = "Option::is_none")] pub completion_tokens_details: Option, /// Responses-style input token details. #[serde(skip_serializing_if = "Option::is_none")] pub input_tokens_details: Option, /// Responses-style output token details. 
#[serde(skip_serializing_if = "Option::is_none")] pub output_tokens_details: Option, } /// Normalized usage totals with optional normalized details. #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct NormalizedUsage { /// Tokens used in the prompt or input. pub prompt_tokens: u32, /// Tokens generated in the completion or output. pub completion_tokens: u32, /// Total tokens used for the request. pub total_tokens: u32, /// Normalized input/output breakdown, when available. #[serde(skip_serializing_if = "Option::is_none")] pub details: Option, } /// Normalized input/output usage breakdown. #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct NormalizedUsageDetails { /// Normalized input usage details. pub input: NormalizedInputUsageDetails, /// Normalized output usage details. pub output: NormalizedOutputUsageDetails, } /// Normalized input usage details. #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct NormalizedInputUsageDetails { /// Tokens retrieved from cache, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub cached_tokens: Option, /// Audio tokens in the input, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub audio_tokens: Option, } impl NormalizedInputUsageDetails { fn is_empty(&self) -> bool { self.cached_tokens.is_none() && self.audio_tokens.is_none() } } /// Normalized output usage details. #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct NormalizedOutputUsageDetails { /// Tokens used for reasoning, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_tokens: Option, /// Audio tokens in the output, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub audio_tokens: Option, /// Accepted prediction tokens, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub accepted_prediction_tokens: Option, /// Rejected prediction tokens, when provided. 
#[serde(skip_serializing_if = "Option::is_none")] pub rejected_prediction_tokens: Option, } impl NormalizedOutputUsageDetails { fn is_empty(&self) -> bool { self.reasoning_tokens.is_none() && self.audio_tokens.is_none() && self.accepted_prediction_tokens.is_none() && self.rejected_prediction_tokens.is_none() } } /// OpenAI-style prompt token details (chat completions). #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct PromptTokensDetails { /// Audio input tokens present in the prompt. #[serde(skip_serializing_if = "Option::is_none")] pub audio_tokens: Option, /// Cached tokens present in the prompt. #[serde(skip_serializing_if = "Option::is_none")] pub cached_tokens: Option, } impl PromptTokensDetails { /// Returns true when no prompt token detail values are set. pub fn is_empty(&self) -> bool { self.audio_tokens.is_none() && self.cached_tokens.is_none() } } /// OpenAI-style completion token details (chat completions). #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct CompletionTokensDetails { /// Tokens accepted from predicted output, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub accepted_prediction_tokens: Option, /// Audio tokens generated by the model, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub audio_tokens: Option, /// Tokens generated by the model for reasoning, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_tokens: Option, /// Tokens rejected from predicted output, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub rejected_prediction_tokens: Option, } impl CompletionTokensDetails { /// Returns true when no completion token detail values are set. pub fn is_empty(&self) -> bool { self.accepted_prediction_tokens.is_none() && self.audio_tokens.is_none() && self.reasoning_tokens.is_none() && self.rejected_prediction_tokens.is_none() } } /// OpenAI-style input token details (Responses API). 
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct InputTokenDetails { /// Tokens retrieved from cache, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub cached_tokens: Option, } /// OpenAI-style output token details (Responses API). #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)] pub struct OutputTokenDetails { /// Tokens used for reasoning, when provided. #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_tokens: Option, } #[cfg(feature = "openai")] mod openai_usage { use super::{ CompletionTokensDetails, InputTokenDetails, OutputTokenDetails, PromptTokensDetails, Usage, UsageDetails, }; use async_openai::types::{ chat::CompletionUsage, embeddings::EmbeddingUsage, responses::ResponseUsage, }; impl From<&CompletionUsage> for Usage { fn from(usage: &CompletionUsage) -> Self { let prompt_details = usage.prompt_tokens_details.as_ref().and_then(|details| { let details = PromptTokensDetails { audio_tokens: details.audio_tokens, cached_tokens: details.cached_tokens, }; if details.is_empty() { None } else { Some(details) } }); let completion_details = usage .completion_tokens_details .as_ref() .and_then(|details| { let details = CompletionTokensDetails { accepted_prediction_tokens: details.accepted_prediction_tokens, audio_tokens: details.audio_tokens, reasoning_tokens: details.reasoning_tokens, rejected_prediction_tokens: details.rejected_prediction_tokens, }; if details.is_empty() { None } else { Some(details) } }); let details = if prompt_details.is_some() || completion_details.is_some() { Some(UsageDetails { prompt_tokens_details: prompt_details, completion_tokens_details: completion_details, input_tokens_details: None, output_tokens_details: None, }) } else { None }; Usage { prompt_tokens: usage.prompt_tokens, completion_tokens: usage.completion_tokens, total_tokens: usage.total_tokens, details, } } } impl From<&ResponseUsage> for Usage { fn from(usage: &ResponseUsage) -> Self { Usage { 
prompt_tokens: usage.input_tokens, completion_tokens: usage.output_tokens, total_tokens: usage.total_tokens, details: Some(UsageDetails { prompt_tokens_details: None, completion_tokens_details: None, input_tokens_details: Some(InputTokenDetails { cached_tokens: Some(usage.input_tokens_details.cached_tokens), }), output_tokens_details: Some(OutputTokenDetails { reasoning_tokens: Some(usage.output_tokens_details.reasoning_tokens), }), }), } } } impl From<&EmbeddingUsage> for Usage { fn from(usage: &EmbeddingUsage) -> Self { Usage { prompt_tokens: usage.prompt_tokens, completion_tokens: 0, total_tokens: usage.total_tokens, details: None, } } } } #[derive(Clone, Builder, Debug, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionResponseDelta { #[builder(default)] pub message_chunk: Option, #[builder(default)] pub tool_calls_chunk: Option>, } // Accumulator for streamed tool calls #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct ToolCallAccum { pub id: Option, pub name: Option, pub arguments: Option, } impl ChatCompletionResponse { pub fn builder() -> ChatCompletionResponseBuilder { ChatCompletionResponseBuilder::default() } pub fn message(&self) -> Option<&str> { self.message.as_deref() } pub fn tool_calls(&self) -> Option<&[ToolCall]> { self.tool_calls.as_deref() } /// Adds a streaming chunk to the message and also the delta pub fn append_message_delta(&mut self, message_delta: Option<&str>) -> &mut Self { // let message: Option = message; let Some(message_delta) = message_delta else { return self; }; if let Some(delta) = &mut self.delta { delta.message_chunk = Some(message_delta.to_string()); } else { self.delta = Some(ChatCompletionResponseDelta { message_chunk: Some(message_delta.to_string()), tool_calls_chunk: None, }); } self.message .as_mut() .map(|m| { m.push_str(message_delta); }) .unwrap_or_else(|| { self.message = Some(message_delta.to_string()); }); self } /// Adds a streaming chunk to the tool calls, if it can be build, the 
tool call will be build, /// otherwise it will remain in the delta and retried on the next call pub fn append_tool_call_delta( &mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>, ) -> &mut Self { if let Some(delta) = &mut self.delta { let map = delta.tool_calls_chunk.get_or_insert_with(HashMap::new); map.entry(index) .and_modify(|v| { if v.id.is_none() { v.id = id.map(Into::into); } if v.name.is_none() { v.name = name.map(Into::into); } if let Some(v) = v.arguments.as_mut() { if let Some(arguments) = arguments { v.push_str(arguments); } } else { v.arguments = arguments.map(Into::into); } }) .or_insert(ToolCallAccum { id: id.map(Into::into), name: name.map(Into::into), arguments: arguments.map(Into::into), }); } else { self.delta = Some(ChatCompletionResponseDelta { message_chunk: None, tool_calls_chunk: Some(HashMap::from([( index, ToolCallAccum { id: id.map(Into::into), name: name.map(Into::into), arguments: arguments.map(Into::into), }, )])), }); } // Now let's try to rebuild _every_ tool call and overwrite // Performance wise very meh but it works, in reality it's only a couple of tool calls most self.finalize_tools_from_stream(); self } pub fn append_usage_delta( &mut self, prompt_tokens: u32, completion_tokens: u32, total_tokens: u32, ) -> &mut Self { debug_assert!(prompt_tokens + completion_tokens == total_tokens); if let Some(usage) = &mut self.usage { usage.prompt_tokens += prompt_tokens; usage.completion_tokens += completion_tokens; usage.total_tokens += total_tokens; } else { self.usage = Some(Usage { prompt_tokens, completion_tokens, total_tokens, details: None, }); } self } fn finalize_tools_from_stream(&mut self) { if let Some(values) = self .delta .as_ref() .and_then(|d| d.tool_calls_chunk.as_ref().map(|t| t.values())) { let maybe_tool_calls = values .filter_map(|maybe_tool_call| { ToolCallBuilder::default() .maybe_id(maybe_tool_call.id.clone()) .maybe_name(maybe_tool_call.name.clone()) 
.maybe_args(maybe_tool_call.arguments.clone()) .build() .ok() }) .collect::>(); if !maybe_tool_calls.is_empty() { self.tool_calls = Some(maybe_tool_calls); } } } } impl ChatCompletionResponseBuilder { pub fn maybe_message>>(&mut self, message: T) -> &mut Self { self.message = Some(message.into()); self } pub fn maybe_tool_calls>>>(&mut self, tool_calls: T) -> &mut Self { self.tool_calls = Some(tool_calls.into()); self } } ================================================ FILE: swiftide-core/src/chat_completion/chat_message.rs ================================================ use std::borrow::Cow; use serde::{Deserialize, Serialize}; use super::tools::{ToolCall, ToolOutput}; /// Reasoning items returned by chat providers that expose chain-of-thought metadata. #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Default)] pub struct ReasoningItem { /// Unique identifier for this reasoning item pub id: String, /// Reasoning summary content pub summary: Vec, /// Reasoning text content #[serde(default, skip_serializing_if = "Option::is_none")] pub content: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub encrypted_content: Option, /// The status of the item. One of `in_progress`, `completed`, or `incomplete`. /// Populated when items are returned via API. 
#[serde(skip_serializing_if = "Option::is_none")] pub status: Option, } #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] #[serde(rename_all = "snake_case")] pub enum ReasoningStatus { InProgress, Completed, Incomplete, } #[derive(Clone, PartialEq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ChatMessageContentSource { Url { url: String, }, Bytes { data: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] media_type: Option, }, S3 { uri: String, #[serde(default, skip_serializing_if = "Option::is_none")] bucket_owner: Option, }, FileId { file_id: String, }, } impl ChatMessageContentSource { pub fn url(url: impl Into) -> Self { Self::Url { url: url.into() } } pub fn bytes(data: impl Into>, media_type: Option) -> Self where M: Into, { Self::Bytes { data: data.into(), media_type: media_type.map(Into::into), } } pub fn s3(uri: impl Into, bucket_owner: Option) -> Self where O: Into, { Self::S3 { uri: uri.into(), bucket_owner: bucket_owner.map(Into::into), } } pub fn file_id(file_id: impl Into) -> Self { Self::FileId { file_id: file_id.into(), } } } impl From for ChatMessageContentSource { fn from(value: String) -> Self { Self::Url { url: value } } } impl From<&str> for ChatMessageContentSource { fn from(value: &str) -> Self { Self::Url { url: value.to_owned(), } } } impl From> for ChatMessageContentSource { fn from(value: Vec) -> Self { Self::Bytes { data: value, media_type: None, } } } impl From<&[u8]> for ChatMessageContentSource { fn from(value: &[u8]) -> Self { Self::Bytes { data: value.to_vec(), media_type: None, } } } impl std::fmt::Debug for ChatMessageContentSource { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ChatMessageContentSource::Url { url } => f .debug_struct("Url") .field("url", &truncate_data_url(url)) .finish(), ChatMessageContentSource::Bytes { data, media_type } => f .debug_struct("Bytes") .field("len", &data.len()) .field("media_type", media_type) 
.finish(), ChatMessageContentSource::S3 { uri, bucket_owner } => f .debug_struct("S3") .field("uri", uri) .field("bucket_owner", bucket_owner) .finish(), ChatMessageContentSource::FileId { file_id } => { f.debug_struct("FileId").field("file_id", file_id).finish() } } } } #[derive(Clone, PartialEq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ChatMessageContentPart { Text { text: String, }, Image { source: ChatMessageContentSource, #[serde(default, skip_serializing_if = "Option::is_none")] format: Option, }, Document { source: ChatMessageContentSource, #[serde(default, skip_serializing_if = "Option::is_none")] format: Option, #[serde(default, skip_serializing_if = "Option::is_none")] name: Option, }, Audio { source: ChatMessageContentSource, #[serde(default, skip_serializing_if = "Option::is_none")] format: Option, }, Video { source: ChatMessageContentSource, #[serde(default, skip_serializing_if = "Option::is_none")] format: Option, }, } impl ChatMessageContentPart { pub fn text(text: impl Into) -> Self { Self::Text { text: text.into() } } pub fn image(source: impl Into) -> Self { Self::Image { source: source.into(), format: None, } } pub fn image_with_format( source: impl Into, format: impl Into, ) -> Self { Self::Image { source: source.into(), format: Some(format.into()), } } pub fn document(source: impl Into) -> Self { Self::Document { source: source.into(), format: None, name: None, } } pub fn document_with_name( source: impl Into, name: impl Into, ) -> Self { Self::Document { source: source.into(), format: None, name: Some(name.into()), } } pub fn audio(source: impl Into) -> Self { Self::Audio { source: source.into(), format: None, } } pub fn video(source: impl Into) -> Self { Self::Video { source: source.into(), format: None, } } } impl std::fmt::Debug for ChatMessageContentPart { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ChatMessageContentPart::Text { text } => { 
f.debug_struct("Text").field("text", text).finish() } ChatMessageContentPart::Image { source, format } => f .debug_struct("Image") .field("source", source) .field("format", format) .finish(), ChatMessageContentPart::Document { source, format, name, } => f .debug_struct("Document") .field("source", source) .field("format", format) .field("name", name) .finish(), ChatMessageContentPart::Audio { source, format } => f .debug_struct("Audio") .field("source", source) .field("format", format) .finish(), ChatMessageContentPart::Video { source, format } => f .debug_struct("Video") .field("source", source) .field("format", format) .finish(), } } } #[derive(Clone, strum_macros::EnumIs, PartialEq, Debug, Serialize, Deserialize)] pub enum ChatMessage { System(String), User(String), UserWithParts(Vec), Assistant(Option, Option>), ToolOutput(ToolCall, ToolOutput), Reasoning(ReasoningItem), // A summary of the chat. If encountered all previous messages are ignored, except the system // prompt Summary(String), } impl std::fmt::Display for ChatMessage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::System(s) => write!(f, "System: \"{s}\""), Self::User(s) => write!(f, "User: \"{s}\""), Self::UserWithParts(parts) => { let (text, attachments) = summarize_user_parts(parts); if attachments == 0 { write!(f, "User: \"{text}\"") } else { write!(f, "User: \"{text}\", attachments: {attachments}") } } Self::Assistant(content, tool_calls) => write!( f, "Assistant: \"{}\", tools: {}", content.as_deref().unwrap_or("None"), tool_calls.as_deref().map_or("None".to_string(), |tc| { tc.iter() .map(ToString::to_string) .collect::>() .join(", ") }) ), Self::ToolOutput(tc, to) => write!(f, "ToolOutput: \"{tc}\": \"{to}\""), Self::Reasoning(item) => write!( f, "Reasoning: \"{}\", encrypted: {}", item.summary.join("\n"), item.encrypted_content.is_some() ), Self::Summary(s) => write!(f, "Summary: \"{s}\""), } } } impl ChatMessage { pub fn new_system(message: impl 
Into) -> Self { Self::System(message.into()) } pub fn new_user(message: impl Into) -> Self { Self::User(message.into()) } pub fn new_user_with_parts(parts: impl Into>) -> Self { Self::UserWithParts(parts.into()) } pub fn new_assistant( message: Option>, tool_calls: Option>, ) -> Self { Self::Assistant(message.map(Into::into), tool_calls) } pub fn new_tool_output(tool_call: impl Into, output: impl Into) -> Self { Self::ToolOutput(tool_call.into(), output.into()) } pub fn new_reasoning(message: ReasoningItem) -> Self { Self::Reasoning(message) } pub fn new_summary(message: impl Into) -> Self { Self::Summary(message.into()) } #[must_use] pub fn to_owned(&self) -> Self { self.clone() } } /// Returns the content of the message as a string slice. /// /// Note that this omits the tool calls from the assistant message. /// /// If used for estimating tokens, consider this a very rought estimate impl AsRef for ChatMessage { fn as_ref(&self) -> &str { match self { Self::System(s) | Self::User(s) | Self::Summary(s) => s, Self::UserWithParts(parts) => match parts.as_slice() { [ChatMessageContentPart::Text { text }] => text.as_ref(), _ => "", }, Self::Assistant(message, _) => message.as_deref().unwrap_or(""), Self::ToolOutput(_, output) => output.content().unwrap_or(""), Self::Reasoning(_) => "", } } } fn summarize_user_parts(parts: &[ChatMessageContentPart]) -> (String, usize) { let mut text_parts = Vec::new(); let mut attachments = 0; for part in parts { match part { ChatMessageContentPart::Text { text } => text_parts.push(text.as_str()), ChatMessageContentPart::Image { .. } | ChatMessageContentPart::Document { .. } | ChatMessageContentPart::Audio { .. } | ChatMessageContentPart::Video { .. 
} => attachments += 1, } } (text_parts.join(" "), attachments) } fn truncate_data_url(url: &str) -> Cow<'_, str> { const MAX_DATA_PREVIEW: usize = 32; if !url.starts_with("data:") { return Cow::Borrowed(url); } let Some((prefix, data)) = url.split_once(',') else { return Cow::Borrowed(url); }; if data.len() <= MAX_DATA_PREVIEW { return Cow::Borrowed(url); } let preview = &data[..MAX_DATA_PREVIEW]; let truncated = data.len() - MAX_DATA_PREVIEW; Cow::Owned(format!( "{prefix},{preview}...[truncated {truncated} chars]" )) } ================================================ FILE: swiftide-core/src/chat_completion/errors.rs ================================================ use std::borrow::Cow; use thiserror::Error; use crate::CommandError; use super::ChatCompletionStream; /// A `ToolError` is an error that occurs when a tool is invoked. /// /// Depending on the agent configuration, the tool might be retried with feedback to the LLM, up to /// a limit. #[derive(Error, Debug)] pub enum ToolError { /// I.e. 
the llm calls the tool with the wrong arguments #[error("arguments for tool failed to parse: {0:#}")] WrongArguments(#[from] serde_json::Error), /// Tool requires arguments but none were provided #[error("arguments missing for tool {0:#}")] MissingArguments(Cow<'static, str>), /// Tool execution failed #[error("tool execution failed: {0:#}")] ExecutionFailed(#[from] CommandError), #[error(transparent)] Unknown(#[from] anyhow::Error), } impl ToolError { /// Tool received arguments that it could not parse pub fn wrong_arguments(e: impl Into) -> Self { ToolError::WrongArguments(e.into()) } /// Tool is missing required arguments pub fn missing_arguments(tool_name: impl Into>) -> Self { ToolError::MissingArguments(tool_name.into()) } /// Tool execution failed pub fn execution_failed(e: impl Into) -> Self { ToolError::ExecutionFailed(e.into()) } /// Tool failed with an unknown error pub fn unknown(e: impl Into) -> Self { ToolError::Unknown(e.into()) } } type BoxedError = Box; #[derive(Error, Debug)] pub enum LanguageModelError { #[error("Context length exceeded: {0:#}")] ContextLengthExceeded(BoxedError), #[error("Permanent error: {0:#}")] PermanentError(BoxedError), #[error("Transient error: {0:#}")] TransientError(BoxedError), } impl LanguageModelError { pub fn permanent(e: impl Into) -> Self { LanguageModelError::PermanentError(e.into()) } pub fn transient(e: impl Into) -> Self { LanguageModelError::TransientError(e.into()) } pub fn context_length_exceeded(e: impl Into) -> Self { LanguageModelError::ContextLengthExceeded(e.into()) } } impl From for LanguageModelError { fn from(e: BoxedError) -> Self { LanguageModelError::PermanentError(e) } } impl From for LanguageModelError { fn from(e: anyhow::Error) -> Self { LanguageModelError::PermanentError(e.into()) } } // Make it easier to use the error in streaming functions impl From for ChatCompletionStream { fn from(val: LanguageModelError) -> Self { Box::pin(futures_util::stream::once(async move { Err(val) })) } } 
================================================ FILE: swiftide-core/src/chat_completion/mod.rs ================================================ //! This module enables the implementation of chat completion on LLM providers //! //! The main trait to implement is `ChatCompletion`, which takes a `ChatCompletionRequest` and //! returns a `ChatCompletionResponse`. //! //! A chat completion request is comprised of a list of `ChatMessage` to complete, with //! optionally tool specifications. The builder accepts either owned or borrowed messages and //! provides `tools(...)` while still exposing `tool_specs` for compatibility. mod chat_completion_request; mod chat_completion_response; mod chat_message; pub mod errors; mod tool_schema; mod tools; // Re-exported in the root per convention pub mod traits; pub use chat_completion_request::*; pub use chat_completion_response::*; pub use chat_message::*; pub use tools::*; pub use traits::*; ================================================ FILE: swiftide-core/src/chat_completion/tool_schema.rs ================================================ use schemars::Schema; use serde_json::{Map, Value, json}; use thiserror::Error; #[derive(Clone, Debug, PartialEq)] pub struct StrictToolParametersSchema { document: Value, } #[derive(Debug, Error)] pub enum ToolSchemaError { #[error("failed to serialize tool schema")] SerializeSchema(#[from] serde_json::Error), #[error("tool schema must be a JSON object")] RootMustBeObject, #[error("tool schema node at {path} must be a JSON object")] NodeMustBeObject { path: String }, #[error("tool schema map at {path} must be a JSON object")] NodeMapMustBeObject { path: String }, #[error("tool schema required must be an array at {path}")] RequiredMustBeArray { path: String }, #[error( "strict tool schemas do not support patternProperties at {path}; define explicit properties instead" )] PatternPropertiesUnsupported { path: String }, #[error( "strict tool schemas do not support propertyNames at {path}; 
define explicit properties instead" )] PropertyNamesUnsupported { path: String }, #[error( "strict tool schemas do not support open object schemas at {path}; define explicit properties instead" )] OpenObjectUnsupported { path: String }, #[error( "strict tool schemas do not support schema-valued additionalProperties at {path}; define explicit properties instead" )] SchemaValuedAdditionalPropertiesUnsupported { path: String }, #[error("strict tool schemas do not support {kind}-valued additionalProperties at {path}")] InvalidAdditionalProperties { path: String, kind: &'static str }, #[error("strict tool schemas do not support $ref siblings {keywords} at {path}")] UnsupportedRefSiblingKeywords { path: String, keywords: String }, } impl StrictToolParametersSchema { pub(super) fn try_from_raw(schema: Option<&Schema>) -> Result { let raw = match schema { Some(schema) => serde_json::to_value(schema)?, None => json!({}), }; let root = raw.as_object().ok_or(ToolSchemaError::RootMustBeObject)?; Ok(Self { document: Value::Object(parse_schema_object(root, &SchemaPath::root(), true)?), }) } pub fn into_json(self) -> Value { self.document } pub fn as_json(&self) -> &Value { &self.document } } fn parse_schema_value(value: &Value, path: &SchemaPath) -> Result { let object = value .as_object() .ok_or_else(|| ToolSchemaError::NodeMustBeObject { path: path.to_string(), })?; Ok(Value::Object(parse_schema_object(object, path, false)?)) } fn parse_schema_object( schema: &Map, path: &SchemaPath, force_object: bool, ) -> Result, ToolSchemaError> { let schema = normalize_schema_object(schema, path)?; if force_object || schema_is_object(&schema) { parse_object_schema(&schema, path) } else { parse_non_object_schema(&schema, path) } } fn normalize_schema_object( schema: &Map, path: &SchemaPath, ) -> Result, ToolSchemaError> { let mut normalized = schema.clone(); rewrite_nullable_type_union(&mut normalized); rewrite_nullable_one_of(&mut normalized); strip_ref_annotation_siblings(&mut 
normalized, path)?; Ok(normalized) } fn rewrite_nullable_type_union(schema: &mut Map) { let Some(entries) = schema.get("type").and_then(Value::as_array) else { return; }; let Some(non_null_type) = nullable_type_union(entries).map(str::to_owned) else { return; }; let mut non_null_branch = schema.clone(); non_null_branch.insert("type".to_string(), Value::String(non_null_type)); let annotations = extract_schema_annotations(schema); for key in schema_annotation_keys() { non_null_branch.remove(*key); } schema.clear(); schema.extend(annotations); schema.insert( "anyOf".to_string(), Value::Array(vec![ Value::Object(non_null_branch), json!({ "type": "null" }), ]), ); } fn rewrite_nullable_one_of(schema: &mut Map) { let Some(entries) = schema.get("oneOf").and_then(Value::as_array).cloned() else { return; }; if is_nullable_union(&entries) { schema.remove("oneOf"); schema.insert("anyOf".to_string(), Value::Array(entries)); } } fn is_nullable_union(entries: &[Value]) -> bool { entries.len() == 2 && entries.iter().any(is_null_schema) } fn nullable_type_union(entries: &[Value]) -> Option<&str> { if entries.len() != 2 { return None; } let mut non_null = None; for entry in entries { let kind = entry.as_str()?; if kind == "null" { continue; } if non_null.is_some() { return None; } non_null = Some(kind); } non_null } fn is_null_schema(value: &Value) -> bool { matches!( value, Value::Object(object) if matches!(object.get("type"), Some(Value::String(kind)) if kind == "null") ) } fn extract_schema_annotations(schema: &Map) -> Map { schema_annotation_keys() .iter() .filter_map(|key| { schema .get(*key) .cloned() .map(|value| ((*key).to_string(), value)) }) .collect() } fn schema_annotation_keys() -> &'static [&'static str] { &[ "description", "title", "default", "examples", "deprecated", "readOnly", "writeOnly", ] } fn strip_ref_annotation_siblings( schema: &mut Map, path: &SchemaPath, ) -> Result<(), ToolSchemaError> { const SAFE_REF_ANNOTATIONS: &[&str] = &[ "description", "title", 
"default", "examples", "deprecated", "readOnly", "writeOnly", ]; if !schema.contains_key("$ref") { return Ok(()); } let mut unsupported = Vec::new(); let sibling_keys = schema .keys() .filter(|key| key.as_str() != "$ref") .cloned() .collect::>(); for key in sibling_keys { if SAFE_REF_ANNOTATIONS.contains(&key.as_str()) { schema.remove(&key); } else { unsupported.push(key); } } if unsupported.is_empty() { Ok(()) } else { Err(ToolSchemaError::UnsupportedRefSiblingKeywords { path: path.to_string(), keywords: unsupported.join(", "), }) } } fn parse_object_schema( schema: &Map, path: &SchemaPath, ) -> Result, ToolSchemaError> { if schema.get("patternProperties").is_some() { return Err(ToolSchemaError::PatternPropertiesUnsupported { path: path.to_string(), }); } if schema.get("propertyNames").is_some() { return Err(ToolSchemaError::PropertyNamesUnsupported { path: path.to_string(), }); } match schema.get("additionalProperties") { Some(Value::Bool(true)) => { return Err(ToolSchemaError::OpenObjectUnsupported { path: path.to_string(), }); } Some(Value::Object(_)) => { return Err( ToolSchemaError::SchemaValuedAdditionalPropertiesUnsupported { path: path.to_string(), }, ); } Some(Value::Array(_)) => { return Err(ToolSchemaError::InvalidAdditionalProperties { path: path.to_string(), kind: "array", }); } Some(Value::Null) => { return Err(ToolSchemaError::InvalidAdditionalProperties { path: path.to_string(), kind: "null", }); } Some(Value::String(_) | Value::Number(_)) => { return Err(ToolSchemaError::InvalidAdditionalProperties { path: path.to_string(), kind: "scalar", }); } Some(Value::Bool(false)) | None => {} } let mut parsed = schema.clone(); parsed.insert("type".to_string(), Value::String("object".to_string())); parsed.insert("additionalProperties".to_string(), Value::Bool(false)); parsed.insert( "properties".to_string(), Value::Object(parse_schema_map( schema.get("properties"), &path.with_key("properties"), )?), ); if let Some(required) = schema.get("required") && 
!required.is_array() { return Err(ToolSchemaError::RequiredMustBeArray { path: path.with_key("required").to_string(), }); } recurse_schema_children(schema, &mut parsed, path)?; Ok(parsed) } fn parse_non_object_schema( schema: &Map, path: &SchemaPath, ) -> Result, ToolSchemaError> { let mut parsed = schema.clone(); recurse_schema_children(schema, &mut parsed, path)?; Ok(parsed) } fn recurse_schema_children( source: &Map, target: &mut Map, path: &SchemaPath, ) -> Result<(), ToolSchemaError> { for key in ["items", "contains", "if", "then", "else", "not"] { if let Some(schema) = source.get(key) { target.insert( key.to_string(), parse_schema_value(schema, &path.with_key(key))?, ); } } for key in ["anyOf", "oneOf", "allOf", "prefixItems"] { if let Some(entries) = source.get(key).and_then(Value::as_array) { target.insert( key.to_string(), Value::Array( entries .iter() .enumerate() .map(|(index, schema)| { parse_schema_value(schema, &path.with_index(key, index)) }) .collect::, _>>()?, ), ); } } for key in ["properties", "$defs", "definitions", "dependentSchemas"] { if let Some(entries) = source.get(key) { target.insert( key.to_string(), Value::Object(parse_schema_map(Some(entries), &path.with_key(key))?), ); } } Ok(()) } fn parse_schema_map( value: Option<&Value>, path: &SchemaPath, ) -> Result, ToolSchemaError> { let Some(value) = value else { return Ok(Map::new()); }; let entries = value .as_object() .ok_or_else(|| ToolSchemaError::NodeMapMustBeObject { path: path.to_string(), })?; let mut parsed = Map::new(); for (key, schema) in entries { parsed.insert( key.clone(), parse_schema_value(schema, &path.with_key(key))?, ); } Ok(parsed) } fn schema_is_object(schema: &Map) -> bool { type_includes_object(schema.get("type")) || schema.contains_key("properties") || schema.contains_key("additionalProperties") || schema.contains_key("patternProperties") || schema.contains_key("propertyNames") } fn type_includes_object(value: Option<&Value>) -> bool { match value { 
Some(Value::String(kind)) => kind == "object", Some(Value::Array(kinds)) => kinds .iter() .filter_map(Value::as_str) .any(|kind| kind == "object"), _ => false, } } #[derive(Clone, Debug)] pub(super) struct SchemaPath(Vec); impl SchemaPath { fn root() -> Self { Self(vec!["$".to_string()]) } fn with_key(&self, key: impl Into) -> Self { let mut path = self.0.clone(); path.push(key.into()); Self(path) } fn with_index(&self, key: impl Into, index: usize) -> Self { let mut path = self.0.clone(); path.push(key.into()); path.push(index.to_string()); Self(path) } } impl std::fmt::Display for SchemaPath { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.join(".")) } } #[cfg(test)] mod tests { use super::*; #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct NestedCommentArgs { request: NestedCommentRequest, } #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct NestedCommentRequest { #[serde(default, skip_serializing_if = "Option::is_none")] body: Option, #[serde(default, skip_serializing_if = "Option::is_none")] text: Option, #[serde(default, skip_serializing_if = "Option::is_none")] page_id: Option, #[serde(default, skip_serializing_if = "Option::is_none")] block_id: Option, #[serde(default, skip_serializing_if = "Option::is_none")] discussion_id: Option, } #[derive(serde::Serialize, serde::Deserialize)] #[serde(transparent)] struct FreeformObject(serde_json::Map); impl schemars::JsonSchema for FreeformObject { fn schema_name() -> std::borrow::Cow<'static, str> { "FreeformObject".into() } fn json_schema(_generator: &mut schemars::SchemaGenerator) -> Schema { serde_json::from_value(json!({ "type": "object", "additionalProperties": true })) .expect("freeform object schema should serialize") } } #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct CreateViewArgs { request: 
CreateViewRequest, } #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct CreateViewRequest { body: FreeformObject, } #[test] fn strict_tool_schema_rejects_nested_freeform_object_wrappers() { let error = StrictToolParametersSchema::try_from_raw(Some(&schemars::schema_for!(CreateViewArgs))) .expect_err("freeform object should be rejected in strict mode"); let message = error.to_string(); assert!(message.contains("strict tool schemas do not support open object schemas")); assert!(message.contains("FreeformObject")); } #[test] fn strict_tool_schema_rewrites_nullable_type_unions_to_any_of() { let schema: Schema = serde_json::from_value(json!({ "type": "object", "properties": { "body": { "type": ["string", "null"] } }, "required": ["body"] })) .expect("schema should deserialize"); let rendered = StrictToolParametersSchema::try_from_raw(Some(&schema)) .unwrap() .into_json(); let body = &rendered["properties"]["body"]; assert!(body.get("type").is_none()); assert!(body.get("oneOf").is_none()); assert_eq!( body["anyOf"], Value::Array(vec![json!({ "type": "string" }), json!({ "type": "null" })]) ); } #[test] fn strict_tool_schema_rewrites_nullable_one_of_to_any_of() { let schema: Schema = serde_json::from_value(json!({ "type": "object", "properties": { "body": { "oneOf": [ { "type": "string" }, { "type": "null" } ] } }, "required": ["body"] })) .expect("schema should deserialize"); let rendered = StrictToolParametersSchema::try_from_raw(Some(&schema)) .unwrap() .into_json(); let body = &rendered["properties"]["body"]; assert!(body.get("type").is_none()); assert!(body.get("oneOf").is_none()); assert_eq!( body["anyOf"], Value::Array(vec![json!({ "type": "string" }), json!({ "type": "null" })]) ); } #[test] fn strict_tool_schema_strips_ref_annotation_siblings() { let schema: Schema = serde_json::from_value(json!({ "type": "object", "properties": { "request": { "$ref": "#/$defs/NestedCommentRequest", "description": "A nested 
payload" } }, "required": ["request"], "$defs": { "NestedCommentRequest": { "type": "object", "properties": { "body": { "type": "string" } }, "required": ["body"] } } })) .expect("schema should deserialize"); let rendered = StrictToolParametersSchema::try_from_raw(Some(&schema)) .unwrap() .into_json(); assert_eq!( rendered["properties"]["request"], json!({ "$ref": "#/$defs/NestedCommentRequest" }) ); } #[test] fn strict_tool_schema_preserves_nullable_numeric_constraints_on_the_non_null_branch() { let schema: Schema = serde_json::from_value(json!({ "$schema": "https://json-schema.org/draft/2020-12/schema", "type": "object", "properties": { "page_size": { "type": ["integer", "null"], "format": "uint", "minimum": 0 } }, "required": ["page_size"] })) .expect("schema should deserialize"); let rendered = StrictToolParametersSchema::try_from_raw(Some(&schema)) .unwrap() .into_json(); assert_eq!( rendered.get("$schema"), Some(&json!("https://json-schema.org/draft/2020-12/schema")) ); let page_size = &rendered["properties"]["page_size"]; assert!(page_size.get("format").is_none()); assert!(page_size.get("minimum").is_none()); assert_eq!( page_size["anyOf"], Value::Array(vec![ json!({ "type": "integer", "format": "uint", "minimum": 0 }), json!({ "type": "null" }) ]) ); } #[test] fn strict_tool_schema_moves_nullable_array_constraints_into_the_array_branch() { let schema: Schema = serde_json::from_value(json!({ "type": "object", "properties": { "children": { "type": ["array", "null"], "items": { "type": "string" } } }, "required": ["children"] })) .expect("schema should deserialize"); let rendered = StrictToolParametersSchema::try_from_raw(Some(&schema)) .unwrap() .into_json(); let children = &rendered["properties"]["children"]; assert!(children.get("items").is_none()); assert_eq!( children["anyOf"], Value::Array(vec![ json!({ "type": "array", "items": { "type": "string" } }), json!({ "type": "null" }) ]) ); } #[test] fn 
strict_tool_schema_preserves_optional_nested_fields_before_provider_shaping() { let schema = StrictToolParametersSchema::try_from_raw(Some(&schemars::schema_for!( NestedCommentArgs ))) .unwrap(); let rendered = schema.into_json(); let nested_ref = rendered["properties"]["request"]["$ref"] .as_str() .expect("nested request should be referenced"); let nested_name = nested_ref .rsplit('/') .next() .expect("nested request ref name"); assert_eq!(rendered["additionalProperties"], Value::Bool(false)); assert!( rendered["$defs"][nested_name].get("required").is_none(), "provider-neutral parsing should not force optional nested fields into required" ); } } ================================================ FILE: swiftide-core/src/chat_completion/tools.rs ================================================ use std::cmp::Ordering; use derive_builder::Builder; use schemars::Schema; use serde::{Deserialize, Serialize}; use serde_json::{Map as JsonMap, Value as JsonValue}; use thiserror::Error; pub use super::tool_schema::{StrictToolParametersSchema, ToolSchemaError}; /// Output of a `ToolCall` which will be added as a message for the agent to use. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)] #[non_exhaustive] pub enum ToolOutput { /// Adds the result of the toolcall to messages Text(String), /// Indicates that the toolcall requires feedback, i.e. 
in a human-in-the-loop FeedbackRequired(Option), /// Indicates that the toolcall failed, but can be handled by the llm Fail(String), /// Stops an agent with an optional message Stop(Option), /// Indicates that the agent failed and should stop AgentFailed(Option), } impl ToolOutput { pub fn text(text: impl Into) -> Self { ToolOutput::Text(text.into()) } pub fn feedback_required(feedback: Option) -> Self { ToolOutput::FeedbackRequired(feedback) } pub fn stop() -> Self { ToolOutput::Stop(None) } pub fn stop_with_args(output: impl Into) -> Self { ToolOutput::Stop(Some(output.into())) } pub fn agent_failed(output: impl Into) -> Self { ToolOutput::AgentFailed(Some(output.into())) } pub fn fail(text: impl Into) -> Self { ToolOutput::Fail(text.into()) } pub fn content(&self) -> Option<&str> { match self { ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s), _ => None, } } /// Get the inner text if the output is a `Text` variant. pub fn as_text(&self) -> Option<&str> { match self { ToolOutput::Text(s) => Some(s), _ => None, } } /// Get the inner text if the output is a `Fail` variant. pub fn as_fail(&self) -> Option<&str> { match self { ToolOutput::Fail(s) => Some(s), _ => None, } } /// Get the inner text if the output is a `Stop` variant. pub fn as_stop(&self) -> Option<&serde_json::Value> { match self { ToolOutput::Stop(args) => args.as_ref(), _ => None, } } /// Get the inner text if the output is an `AgentFailed` variant. pub fn as_agent_failed(&self) -> Option<&serde_json::Value> { match self { ToolOutput::AgentFailed(args) => args.as_ref(), _ => None, } } /// Get the inner feedback if the output is a `FeedbackRequired` variant. 
pub fn as_feedback_required(&self) -> Option<&serde_json::Value> { match self { ToolOutput::FeedbackRequired(args) => args.as_ref(), _ => None, } } } impl> From for ToolOutput { fn from(value: S) -> Self { ToolOutput::Text(value.as_ref().to_string()) } } impl std::fmt::Display for ToolOutput { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ToolOutput::Text(value) => write!(f, "{value}"), ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"), ToolOutput::Stop(args) => { if let Some(value) = args { write!(f, "Stop {value}") } else { write!(f, "Stop") } } ToolOutput::FeedbackRequired(_) => { write!(f, "Feedback required") } ToolOutput::AgentFailed(args) => write!( f, "Agent failed with output: {}", args.as_ref().unwrap_or_default() ), } } } /// A tool call that can be executed by the executor #[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)] #[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))] #[builder(setter(into, strip_option))] pub struct ToolCall { id: String, name: String, #[builder(default)] args: Option, } /// Hash is used for finding tool calls that have been retried by agents impl std::hash::Hash for ToolCall { fn hash(&self, state: &mut H) { self.name.hash(state); self.args.hash(state); } } impl std::fmt::Display for ToolCall { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{id}#{name} {args}", id = self.id, name = self.name, args = self.args.as_deref().unwrap_or("") ) } } impl ToolCall { pub fn builder() -> ToolCallBuilder { ToolCallBuilder::default() } pub fn id(&self) -> &str { &self.id } pub fn name(&self) -> &str { &self.name } pub fn args(&self) -> Option<&str> { self.args.as_deref() } pub fn with_args(&mut self, args: Option) { self.args = args; } } impl ToolCallBuilder { pub fn maybe_args>>(&mut self, args: T) -> &mut Self { self.args = Some(args.into()); self } pub fn maybe_id>>(&mut self, id: T) -> &mut Self { self.id = id.into(); 
self } pub fn maybe_name>>(&mut self, name: T) -> &mut Self { self.name = name.into(); self } } /// A typed tool specification intended to be usable for multiple LLMs /// /// i.e. the json spec `OpenAI` uses to define their tools #[derive(Clone, Debug, Serialize, Deserialize, Builder, Default)] #[builder(setter(into), derive(Debug, Serialize, Deserialize), build_fn(skip))] #[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))] #[serde(deny_unknown_fields)] pub struct ToolSpec { /// Name of the tool pub name: String, /// Description passed to the LLM for the tool pub description: String, #[builder(default, setter(strip_option))] #[serde(skip_serializing_if = "Option::is_none")] /// Optional JSON schema describing the tool arguments pub parameters_schema: Option, } #[derive(Debug, Error)] pub enum ToolSpecError { #[error(transparent)] InvalidParametersSchema(#[from] ToolSchemaError), } #[derive(Debug, Error)] pub enum ToolSpecBuildError { #[error("missing required field `{field}`")] MissingField { field: &'static str }, #[error(transparent)] InvalidParametersSchema(#[from] ToolSchemaError), } impl ToolSpec { pub fn builder() -> ToolSpecBuilder { ToolSpecBuilder::default() } /// Returns the provider-neutral strict parameters schema for this tool. /// /// # Errors /// /// Returns an error when the configured parameters schema is not compatible /// with Swiftide's strict tool-schema contract. pub fn strict_parameters_schema(&self) -> Result { Ok(StrictToolParametersSchema::try_from_raw( self.parameters_schema.as_ref(), )?) } /// Returns the provider-neutral strict parameters schema with deterministic JSON key ordering. /// /// # Errors /// /// Returns an error when the configured parameters schema is not compatible /// with Swiftide's strict tool-schema contract. 
pub fn canonical_parameters_schema_json(&self) -> Result { Ok(canonicalize_json( self.strict_parameters_schema()?.into_json(), )) } } impl ToolSpecBuilder { /// Builds a tool specification and validates its parameters schema. /// /// # Errors /// /// Returns an error when a required field is missing or when the provided /// parameters schema is not compatible with Swiftide's strict tool-schema /// contract. pub fn build(&self) -> Result { let name = self .name .clone() .ok_or(ToolSpecBuildError::MissingField { field: "name" })?; let description = self .description .clone() .ok_or(ToolSpecBuildError::MissingField { field: "description", })?; let parameters_schema = self.parameters_schema.clone().unwrap_or(None); StrictToolParametersSchema::try_from_raw(parameters_schema.as_ref())?; Ok(ToolSpec { name, description, parameters_schema, }) } } impl PartialEq for ToolSpec { fn eq(&self, other: &Self) -> bool { self.name == other.name && self.description == other.description && tool_spec_schema_key(self) == tool_spec_schema_key(other) } } impl Eq for ToolSpec {} impl PartialOrd for ToolSpec { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for ToolSpec { fn cmp(&self, other: &Self) -> Ordering { self.name .cmp(&other.name) .then_with(|| self.description.cmp(&other.description)) .then_with(|| tool_spec_schema_key(self).cmp(&tool_spec_schema_key(other))) } } impl std::hash::Hash for ToolSpec { fn hash(&self, state: &mut H) { self.name.hash(state); self.description.hash(state); tool_spec_schema_key(self).hash(state); } } fn tool_spec_schema_key(spec: &ToolSpec) -> String { spec.canonical_parameters_schema_json() .ok() .or_else(|| { spec.parameters_schema .as_ref() .and_then(|schema| serde_json::to_value(schema).ok()) .map(canonicalize_json) }) .and_then(|schema| serde_json::to_string(&schema).ok()) .unwrap_or_default() } pub fn canonicalize_json(value: JsonValue) -> JsonValue { match value { JsonValue::Object(object) => { let mut keys = 
object.keys().cloned().collect::>(); keys.sort(); let mut sorted = JsonMap::with_capacity(object.len()); for key in keys { if let Some(child) = object.get(&key) { sorted.insert(key, canonicalize_json(child.clone())); } } JsonValue::Object(sorted) } JsonValue::Array(values) => { JsonValue::Array(values.into_iter().map(canonicalize_json).collect()) } scalar => scalar, } } #[cfg(test)] mod tests { use super::*; use serde_json::{Value, json}; use std::collections::{BTreeSet, HashSet}; use std::hash::{DefaultHasher, Hash, Hasher}; #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] struct ExampleArgs { value: String, } #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] struct NestedCommentArgs { request: NestedCommentRequest, } #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct NestedCommentRequest { #[serde(default, skip_serializing_if = "Option::is_none")] body: Option, #[serde(default, skip_serializing_if = "Option::is_none")] text: Option, #[serde(default, skip_serializing_if = "Option::is_none")] page_id: Option, #[serde(default, skip_serializing_if = "Option::is_none")] block_id: Option, #[serde(default, skip_serializing_if = "Option::is_none")] discussion_id: Option, } #[derive(serde::Serialize, serde::Deserialize)] #[serde(transparent)] struct FreeformObject(serde_json::Map); impl schemars::JsonSchema for FreeformObject { fn schema_name() -> std::borrow::Cow<'static, str> { "FreeformObject".into() } fn json_schema(_generator: &mut schemars::SchemaGenerator) -> Schema { serde_json::from_value(json!({ "type": "object", "additionalProperties": true })) .expect("freeform object schema should serialize") } } #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct CreateViewArgs { request: CreateViewRequest, } #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct 
CreateViewRequest { body: FreeformObject, } #[test] fn tool_spec_serializes_schema() { let schema = schemars::schema_for!(ExampleArgs); let spec = ToolSpec::builder() .name("example") .description("An example tool") .parameters_schema(schema) .build() .unwrap(); let json = serde_json::to_value(&spec).unwrap(); assert_eq!(json.get("name").and_then(|v| v.as_str()), Some("example")); assert!(json.get("parameters_schema").is_some()); } #[test] fn tool_spec_is_hashable() { let schema = schemars::schema_for!(ExampleArgs); let spec = ToolSpec::builder() .name("example") .description("An example tool") .parameters_schema(schema) .build() .unwrap(); let mut set = HashSet::new(); set.insert(spec.clone()); assert!(set.contains(&spec)); } #[test] fn tool_spec_hash_is_stable_across_schema_key_order() { let first = ToolSpec::builder() .name("create_view") .description("Create a view") .parameters_schema( serde_json::from_value::(json!({ "type": "object", "properties": { "body": { "type": "string" }, "name": { "type": "string" } } })) .unwrap(), ) .build() .unwrap(); let second = ToolSpec::builder() .name("create_view") .description("Create a view") .parameters_schema( serde_json::from_value::(json!({ "properties": { "name": { "type": "string" }, "body": { "type": "string" } }, "type": "object" })) .unwrap(), ) .build() .unwrap(); let mut first_hasher = DefaultHasher::new(); first.hash(&mut first_hasher); let mut second_hasher = DefaultHasher::new(); second.hash(&mut second_hasher); assert_eq!(first_hasher.finish(), second_hasher.finish()); } #[test] fn tool_spec_order_is_stable_across_schema_key_order() { let first = ToolSpec::builder() .name("create_view") .description("Create a view") .parameters_schema( serde_json::from_value::(json!({ "type": "object", "properties": { "body": { "type": "string" }, "name": { "type": "string" } } })) .unwrap(), ) .build() .unwrap(); let second = ToolSpec::builder() .name("create_view") .description("Create a view") .parameters_schema( 
serde_json::from_value::(json!({ "properties": { "name": { "type": "string" }, "body": { "type": "string" } }, "type": "object" })) .unwrap(), ) .build() .unwrap(); let set = BTreeSet::from([first, second]); assert_eq!(set.len(), 1); } #[test] fn strict_parameters_schema_returns_canonical_nested_schema() { let spec = ToolSpec::builder() .name("comment") .description("Create a comment") .parameters_schema(schemars::schema_for!(NestedCommentArgs)) .build() .unwrap(); let normalized = spec.strict_parameters_schema().unwrap().into_json(); assert_eq!(normalized["type"], Value::String("object".into())); assert_eq!(normalized["additionalProperties"], Value::Bool(false)); assert_eq!( normalized["required"], Value::Array(vec![Value::String("request".into())]) ); let nested_ref = normalized["properties"]["request"]["$ref"] .as_str() .expect("nested request should be referenced"); let nested_name = nested_ref .rsplit('/') .next() .expect("nested request ref name"); assert!( normalized["$defs"][nested_name].get("required").is_none(), "strict schema parsing should preserve optional nested fields before provider shaping" ); } #[test] fn strict_parameters_schema_sets_additional_properties_false_on_nested_typed_objects() { let spec = ToolSpec::builder() .name("comment") .description("Create a comment") .parameters_schema(schemars::schema_for!(NestedCommentArgs)) .build() .unwrap(); let normalized = spec.strict_parameters_schema().unwrap().into_json(); let nested_ref = normalized["properties"]["request"]["$ref"] .as_str() .expect("nested request should be referenced"); let nested_name = nested_ref .rsplit('/') .next() .expect("nested request ref name"); assert_eq!( normalized["$defs"][nested_name]["additionalProperties"], Value::Bool(false) ); } #[test] fn tool_spec_builder_rejects_nested_freeform_objects_in_strict_mode() { let error = ToolSpec::builder() .name("create_view") .description("Create a view") .parameters_schema(schemars::schema_for!(CreateViewArgs)) .build() 
.expect_err("freeform object should be rejected in strict mode");

        let message = error.to_string();
        assert!(message.contains("strict tool schemas do not support open object schemas"));
        assert!(message.contains("FreeformObject"));
    }

    #[test]
    fn strict_parameters_schema_preserves_optional_nested_fields() {
        let spec = ToolSpec::builder()
            .name("comment")
            .description("Create a comment")
            .parameters_schema(schemars::schema_for!(NestedCommentArgs))
            .build()
            .unwrap();

        let normalized = spec.strict_parameters_schema().unwrap().into_json();

        assert_eq!(normalized["type"], Value::String("object".into()));
        assert_eq!(normalized["additionalProperties"], Value::Bool(false));
        assert_eq!(
            normalized["required"],
            Value::Array(vec![Value::String("request".into())])
        );

        // Optional nested fields must survive normalization; provider-specific
        // shaping happens later.
        let nested_ref = normalized["properties"]["request"]["$ref"]
            .as_str()
            .expect("nested request should be referenced");
        let nested_name = nested_ref
            .rsplit('/')
            .next()
            .expect("nested request ref name");

        assert_eq!(
            normalized["$defs"][nested_name]["additionalProperties"],
            Value::Bool(false)
        );
        assert!(normalized["$defs"][nested_name].get("required").is_none());
    }
}

================================================
FILE: swiftide-core/src/chat_completion/traits.rs
================================================

use anyhow::Result;
use async_trait::async_trait;
use dyn_clone::DynClone;
use futures_util::Stream;
use std::{borrow::Cow, pin::Pin, sync::Arc};

use crate::AgentContext;

use super::{
    ToolCall, ToolOutput, ToolSpec, chat_completion_request::ChatCompletionRequest,
    chat_completion_response::ChatCompletionResponse,
    errors::{LanguageModelError, ToolError},
};

/// A pinned, boxed stream of completion responses.
// NOTE(review): generic parameters were lost in extraction; reconstructed as
// `Result<ChatCompletionResponse, LanguageModelError>` from the trait methods below — confirm
// against upstream.
pub type ChatCompletionStream =
    Pin<Box<dyn Stream<Item = Result<ChatCompletionResponse, LanguageModelError>> + Send>>;

#[async_trait]
pub trait ChatCompletion: Send + Sync + DynClone {
    /// Complete a chat request, returning a single response or a language model error.
    async fn complete(
        &self,
        request: &ChatCompletionRequest<'_>,
    ) -> Result<ChatCompletionResponse, LanguageModelError>;

    /// Stream the completion response.
If it's not supported, it will return a single /// response async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream { Box::pin(tokio_stream::iter(vec![self.complete(request).await])) } } #[async_trait] impl ChatCompletion for Box { async fn complete( &self, request: &ChatCompletionRequest<'_>, ) -> Result { (**self).complete(request).await } async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream { (**self).complete_stream(request).await } } #[async_trait] impl ChatCompletion for &dyn ChatCompletion { async fn complete( &self, request: &ChatCompletionRequest<'_>, ) -> Result { (**self).complete(request).await } async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream { (**self).complete_stream(request).await } } #[async_trait] impl ChatCompletion for &T where T: ChatCompletion + Clone + 'static, { async fn complete( &self, request: &ChatCompletionRequest<'_>, ) -> Result { (**self).complete(request).await } async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream { (**self).complete_stream(request).await } } impl From<&LLM> for Box where LLM: ChatCompletion + Clone + 'static, { fn from(llm: &LLM) -> Self { Box::new(llm.clone()) as Box } } dyn_clone::clone_trait_object!(ChatCompletion); /// The `Tool` trait is the main interface for chat completion and agent tools. /// /// `swiftide-macros` provides a set of macros to generate implementations of this trait. If you /// need more control over the implementation, you can implement the trait manually. /// /// The `ToolSpec` is what will end up with the LLM. A builder is provided. The `name` is expected /// to be unique, and is used to identify the tool. It should be the same as the name in the /// `ToolSpec`. 
#[async_trait] pub trait Tool: Send + Sync + DynClone { // tbd async fn invoke( &self, agent_context: &dyn AgentContext, tool_call: &ToolCall, ) -> Result; fn name(&self) -> Cow<'_, str>; fn tool_spec(&self) -> ToolSpec; fn boxed<'a>(self) -> Box where Self: Sized + 'a, { Box::new(self) as Box } } /// A toolbox is a collection of tools /// /// It can be a list, an mcp client, or anything else we can think of. /// /// This allows agents to not know their tools when they are created, and to get them at runtime. /// /// It also allows for tools to be dynamically loaded and unloaded, etc. #[async_trait] pub trait ToolBox: Send + Sync + DynClone { async fn available_tools(&self) -> Result>>; fn name(&self) -> Cow<'_, str> { Cow::Borrowed("Unnamed ToolBox") } fn boxed<'a>(self) -> Box where Self: Sized + 'a, { Box::new(self) as Box } } #[async_trait] impl ToolBox for Vec> { async fn available_tools(&self) -> Result>> { Ok(self.clone()) } } #[async_trait] impl ToolBox for Box { async fn available_tools(&self) -> Result>> { (**self).available_tools().await } } #[async_trait] impl ToolBox for Arc { async fn available_tools(&self) -> Result>> { (**self).available_tools().await } } #[async_trait] impl ToolBox for &dyn ToolBox { async fn available_tools(&self) -> Result>> { (**self).available_tools().await } } #[async_trait] impl ToolBox for &[Box] { async fn available_tools(&self) -> Result>> { Ok(self.to_vec()) } } #[async_trait] impl ToolBox for [Box] { async fn available_tools(&self) -> Result>> { Ok(self.to_vec()) } } dyn_clone::clone_trait_object!(ToolBox); #[async_trait] impl Tool for Box { async fn invoke( &self, agent_context: &dyn AgentContext, tool_call: &ToolCall, ) -> Result { (**self).invoke(agent_context, tool_call).await } fn name(&self) -> Cow<'_, str> { (**self).name() } fn tool_spec(&self) -> ToolSpec { (**self).tool_spec() } } dyn_clone::clone_trait_object!(Tool); /// Tools are identified and unique by name /// These allow comparison and lookups impl 
PartialEq for Box { fn eq(&self, other: &Self) -> bool { self.name() == other.name() } } impl Eq for Box {} impl std::hash::Hash for Box { fn hash(&self, state: &mut H) { self.name().hash(state); } } ================================================ FILE: swiftide-core/src/document.rs ================================================ //! Documents are the main data structure that is retrieved via the query pipeline //! //! Retrievers are expected to eagerly set any configured metadata on the document, with the same //! field name used during indexing if applicable. use std::fmt; use derive_builder::Builder; use serde::{Deserialize, Serialize}; use crate::{metadata::Metadata, util::debug_long_utf8}; /// A document represents a single unit of retrieved text #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Builder)] #[builder(setter(into))] pub struct Document { #[builder(default)] metadata: Metadata, content: String, } impl From for serde_json::Value { fn from(document: Document) -> Self { serde_json::json!({ "metadata": document.metadata, "content": document.content, }) } } impl From<&Document> for serde_json::Value { fn from(document: &Document) -> Self { serde_json::json!({ "metadata": document.metadata, "content": document.content, }) } } impl PartialOrd for Document { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for Document { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.content.cmp(&other.content) } } impl fmt::Debug for Document { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Document") .field("metadata", &self.metadata) .field("content", &debug_long_utf8(&self.content, 100)) .finish() } } impl> From for Document { fn from(value: T) -> Self { Document::new(value.as_ref(), None) } } impl Document { pub fn new(content: impl Into, metadata: Option) -> Self { Self { metadata: metadata.unwrap_or_default(), content: content.into(), } } pub fn builder() -> DocumentBuilder { 
DocumentBuilder::default()
    }

    /// Borrowed view of the document text.
    pub fn content(&self) -> &str {
        &self.content
    }

    /// Borrowed view of the document metadata.
    pub fn metadata(&self) -> &Metadata {
        &self.metadata
    }

    /// Raw bytes of the document content.
    pub fn bytes(&self) -> &[u8] {
        self.content.as_bytes()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::Metadata;

    #[test]
    fn test_document_creation() {
        let content = "Test content";
        let metadata = Metadata::from([("some", "metadata")]);
        let document = Document::new(content, Some(metadata.clone()));
        assert_eq!(document.content(), content);
        assert_eq!(document.metadata(), &metadata);
    }

    #[test]
    fn test_document_default_metadata() {
        let content = "Test content";
        let document = Document::new(content, None);
        assert_eq!(document.content(), content);
        assert_eq!(document.metadata(), &Metadata::default());
    }

    #[test]
    fn test_document_from_str() {
        let content = "Test content";
        let document: Document = content.into();
        assert_eq!(document.content(), content);
        assert_eq!(document.metadata(), &Metadata::default());
    }

    #[test]
    fn test_document_partial_ord() {
        let doc1 = Document::new("A", None);
        let doc2 = Document::new("B", None);
        assert!(doc1 < doc2);
    }

    #[test]
    fn test_document_ord() {
        let doc1 = Document::new("A", None);
        let doc2 = Document::new("B", None);
        assert!(doc1.cmp(&doc2) == std::cmp::Ordering::Less);
    }

    #[test]
    fn test_document_debug() {
        let content = "Test content";
        let document = Document::new(content, None);
        let debug_str = format!("{document:?}");
        assert!(debug_str.contains("Document"));
        assert!(debug_str.contains("metadata"));
        assert!(debug_str.contains("content"));
    }

    #[test]
    fn test_document_to_json() {
        let content = "Test content";
        let metadata = Metadata::from([("some", "metadata")]);
        let document = Document::new(content, Some(metadata.clone()));
        let json_value: serde_json::Value = document.into();
        assert_eq!(json_value["content"], content);
        assert_eq!(json_value["metadata"], serde_json::json!(metadata));
    }

    #[test]
    fn test_document_ref_to_json() {
        let content = "Test content";
        let metadata = Metadata::from([("some",
"metadata")]);
        let document = Document::new(content, Some(metadata.clone()));
        let json_value: serde_json::Value = (&document).into();
        assert_eq!(json_value["content"], content);
        assert_eq!(json_value["metadata"], serde_json::json!(metadata));
    }
}

================================================
FILE: swiftide-core/src/indexing_decorators.rs
================================================

use std::fmt::Debug;

use crate::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
use crate::stream_backoff::{StreamBackoff, TokioSleeper};
use crate::{ChatCompletion, ChatCompletionStream};
use crate::{EmbeddingModel, Embeddings, SimplePrompt, prompt::Prompt};

use crate::chat_completion::errors::LanguageModelError;
use anyhow::Result;
use async_trait::async_trait;
use futures_util::{StreamExt as _, TryStreamExt as _};
use std::time::Duration;

/// Backoff configuration for api calls.
/// Each time an api call fails backoff will wait an increasing period of time for each subsequent
/// retry attempt. See the `backoff` crate documentation for more details.
#[derive(Debug, Clone, Copy)]
pub struct BackoffConfiguration {
    /// Initial interval in seconds between retries
    pub initial_interval_sec: u64,
    /// The factor by which the interval is multiplied on each retry attempt
    pub multiplier: f64,
    /// Introduces randomness to avoid retry storms
    pub randomization_factor: f64,
    /// Total time all attempts are allowed in seconds. Once a retry must wait longer than this,
    /// the request is considered to have failed.
    pub max_elapsed_time_sec: u64,
}

impl Default for BackoffConfiguration {
    fn default() -> Self {
        Self {
            initial_interval_sec: 1,
            multiplier: 2.0,
            randomization_factor: 0.5,
            max_elapsed_time_sec: 60,
        }
    }
}

/// Decorator that retries the wrapped language model with exponential backoff on
/// transient errors; permanent and context-length errors are never retried.
#[derive(Debug, Clone)]
pub struct LanguageModelWithBackOff<P: Clone> {
    pub(crate) inner: P,
    config: BackoffConfiguration,
}

impl<P: Clone> LanguageModelWithBackOff<P>

{
    /// Wraps `client` with the given backoff configuration.
    pub fn new(client: P, config: BackoffConfiguration) -> Self {
        Self {
            inner: client,
            config,
        }
    }

    /// Builds the exponential backoff strategy from the stored configuration.
    pub(crate) fn strategy(&self) -> backoff::ExponentialBackoff {
        backoff::ExponentialBackoffBuilder::default()
            .with_initial_interval(Duration::from_secs(self.config.initial_interval_sec))
            .with_multiplier(self.config.multiplier)
            .with_max_elapsed_time(Some(Duration::from_secs(self.config.max_elapsed_time_sec)))
            .with_randomization_factor(self.config.randomization_factor)
            .build()
    }
}

#[async_trait]
impl<P: SimplePrompt + Clone> SimplePrompt for LanguageModelWithBackOff<P>

{
    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
        let strategy = self.strategy();

        let op = || {
            let prompt = prompt.clone();
            async {
                self.inner.prompt(prompt).await.map_err(|e| match e {
                    // Context-length and permanent errors will not succeed on retry.
                    LanguageModelError::ContextLengthExceeded(e) => {
                        backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
                    }
                    LanguageModelError::PermanentError(e) => {
                        backoff::Error::Permanent(LanguageModelError::PermanentError(e))
                    }
                    LanguageModelError::TransientError(e) => {
                        backoff::Error::transient(LanguageModelError::TransientError(e))
                    }
                })
            }
        };

        backoff::future::retry(strategy, op).await
    }

    fn name(&self) -> &'static str {
        self.inner.name()
    }
}

#[async_trait]
impl<P: EmbeddingModel + Clone> EmbeddingModel for LanguageModelWithBackOff<P>

{ async fn embed(&self, input: Vec) -> Result { self.inner.embed(input).await } fn name(&self) -> &'static str { self.inner.name() } } #[async_trait] impl ChatCompletion for LanguageModelWithBackOff { async fn complete( &self, request: &ChatCompletionRequest<'_>, ) -> Result { let strategy = self.strategy(); let op = || async move { self.inner.complete(request).await.map_err(|e| match e { LanguageModelError::ContextLengthExceeded(e) => { backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e)) } LanguageModelError::PermanentError(e) => { backoff::Error::Permanent(LanguageModelError::PermanentError(e)) } LanguageModelError::TransientError(e) => { backoff::Error::transient(LanguageModelError::TransientError(e)) } }) }; backoff::future::retry(strategy, op).await } async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream { let strategy = self.strategy(); let stream = self.inner.complete_stream(request).await; let stream = stream .map_err(|e| match e { LanguageModelError::ContextLengthExceeded(e) => { backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e)) } LanguageModelError::PermanentError(e) => { backoff::Error::Permanent(LanguageModelError::PermanentError(e)) } LanguageModelError::TransientError(e) => { backoff::Error::transient(LanguageModelError::TransientError(e)) } }) .boxed(); StreamBackoff::new(stream, strategy, TokioSleeper) .map_err(|e| match e { backoff::Error::Permanent(e) => e, backoff::Error::Transient { err, .. 
} => err,
            })
            .boxed()
    }
}

#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mock that fails `should_fail_count` times before succeeding.
    #[derive(Debug, Clone)]
    struct MockSimplePrompt {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[derive(Debug, Clone, Copy)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest<'_>,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.should_fail_count {
                match self.error_type {
                    MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                    ))),
                    MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                    ))),
                    MockErrorType::ContextLengthExceeded => Err(
                        LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ))),
                    ),
                }
            } else {
                Ok(ChatCompletionResponse {
                    id: Uuid::new_v4(),
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                    delta: None,
                    usage: None,
                    reasoning: None,
                })
            }
        }
    }

    #[async_trait]
    impl SimplePrompt for MockSimplePrompt {
        async fn prompt(&self, _prompt: Prompt) -> Result<String, LanguageModelError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.should_fail_count {
                match self.error_type {
                    MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                    ))),
                    MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                    ))),
                    MockErrorType::ContextLengthExceeded
=> Err(
                        LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ))),
                    ),
                }
            } else {
                Ok("Success response".to_string())
            }
        }

        fn name(&self) -> &'static str {
            "MockSimplePrompt"
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 2, // Fail twice, succeed on third attempt
            error_type: MockErrorType::Transient,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(result.unwrap(), "Success response");
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 1,
            error_type: MockErrorType::Permanent,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        match result {
            Err(LanguageModelError::PermanentError(_)) => {} // Expected
            _ => panic!("Expected PermanentError"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_context_length_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 1,
            error_type:
MockErrorType::ContextLengthExceeded,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        match result {
            Err(LanguageModelError::ContextLengthExceeded(_)) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Fail twice, succeed on third attempt
            error_type: MockErrorType::Transient,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);
        let request: ChatCompletionRequest<'static> = Vec::new().into();

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Would fail twice if retried
            error_type: MockErrorType::Permanent,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);
        let request: ChatCompletionRequest<'static> = Vec::new().into();

        let result = model_with_backoff.complete(&request).await;
assert!(result.is_err()); assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should only be called once match result { Err(LanguageModelError::PermanentError(_)) => {} // Expected _ => panic!("Expected PermanentError, got {result:?}"), } } #[tokio::test] async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors() { let call_count = Arc::new(AtomicUsize::new(0)); let mock_chat = MockChatCompletion { call_count: call_count.clone(), should_fail_count: 2, // Would fail twice if retried error_type: MockErrorType::ContextLengthExceeded, }; let config = BackoffConfiguration { initial_interval_sec: 1, max_elapsed_time_sec: 10, multiplier: 1.5, randomization_factor: 0.5, }; let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config); let request: ChatCompletionRequest<'static> = Vec::new().into(); let result = model_with_backoff.complete(&request).await; assert!(result.is_err()); assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should only be called once match result { Err(LanguageModelError::ContextLengthExceeded(_)) => {} // Expected _ => panic!("Expected ContextLengthExceeded, got {result:?}"), } } } ================================================ FILE: swiftide-core/src/indexing_defaults.rs ================================================ use std::sync::Arc; use crate::SimplePrompt; #[derive(Debug, Default, Clone)] pub struct IndexingDefaults(Arc); #[derive(Debug, Default)] pub struct IndexingDefaultsInner { simple_prompt: Option>, } impl IndexingDefaults { pub fn simple_prompt(&self) -> Option<&dyn SimplePrompt> { self.0.simple_prompt.as_deref() } pub fn from_simple_prompt(simple_prompt: Box) -> Self { Self(Arc::new(IndexingDefaultsInner { simple_prompt: Some(simple_prompt), })) } } ================================================ FILE: swiftide-core/src/indexing_stream.rs ================================================ #![allow(clippy::from_over_into)] //! 
This module defines the `IndexingStream` type, which is used internally by a pipeline for //! handling asynchronous streams of `Node` items in the indexing pipeline. use crate::node::{Chunk, Node}; use anyhow::Result; use futures_util::stream::{self, Stream}; use std::pin::Pin; use tokio::sync::mpsc::Receiver; pub use futures_util::StreamExt; // We need to inform the compiler that `inner` is pinned as well /// An asynchronous stream of `Node` items. /// /// Wraps an internal stream of `Result>` items. /// /// Streams, iterators and vectors of `Result>` can be converted into an `IndexingStream`. #[pin_project::pin_project] pub struct IndexingStream { #[pin] pub(crate) inner: Pin>> + Send>>, } impl Stream for IndexingStream { type Item = Result>; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); this.inner.poll_next(cx) } } impl Into> for Vec>> { fn into(self) -> IndexingStream { IndexingStream::iter(self) } } impl Into> for Vec> { fn into(self) -> IndexingStream { IndexingStream::from_nodes(self) } } // impl Into for anyhow::Error { // fn into(self) -> IndexingStream { // IndexingStream::iter(vec![Err(self)]) // } // } impl Into> for Result>> { fn into(self) -> IndexingStream { match self { Ok(nodes) => IndexingStream::iter(nodes.into_iter().map(Ok)), Err(err) => IndexingStream::iter(vec![Err(err)]), } } } impl Into> for Pin>> + Send>> { fn into(self) -> IndexingStream { IndexingStream { inner: self } } } impl Into> for Receiver>> { fn into(self) -> IndexingStream { IndexingStream { inner: tokio_stream::wrappers::ReceiverStream::new(self).boxed(), } } } impl From for IndexingStream { fn from(err: anyhow::Error) -> Self { IndexingStream::iter(vec![Err(err)]) } } impl IndexingStream { pub fn empty() -> Self { IndexingStream { inner: stream::empty().boxed(), } } /// Creates an `IndexingStream` from an iterator of `Result>`. 
/// /// WARN: Also works with Err items directly, which will result /// in an _incorrect_ stream pub fn iter(iter: I) -> Self where I: IntoIterator>> + Send + 'static, ::IntoIter: Send, { IndexingStream { inner: stream::iter(iter).boxed(), } } pub fn from_nodes(nodes: Vec>) -> Self { IndexingStream::iter(nodes.into_iter().map(Ok)) } } ================================================ FILE: swiftide-core/src/indexing_traits.rs ================================================ //! Traits in Swiftide allow for easy extendability //! //! All steps defined in the indexing pipeline and the generic transformers can also take a //! trait. To bring your own transformers, models and loaders, all you need to do is implement the //! trait and it should work out of the box. use crate::Embeddings; use crate::node::{Chunk, Node}; use crate::{ SparseEmbeddings, indexing_defaults::IndexingDefaults, indexing_stream::IndexingStream, }; use std::fmt::Debug; use std::sync::Arc; use crate::chat_completion::errors::LanguageModelError; use crate::prompt::Prompt; use anyhow::Result; use async_trait::async_trait; pub use dyn_clone::DynClone; /// All traits are easily mockable under tests #[cfg(feature = "test-utils")] #[doc(hidden)] use mockall::{mock, predicate::str}; use schemars::{JsonSchema, schema_for}; use serde::de::DeserializeOwned; #[async_trait] /// Transforms single nodes into single nodes pub trait Transformer: Send + Sync + DynClone { type Input: Chunk; type Output: Chunk; async fn transform_node(&self, node: Node) -> Result>; /// Overrides the default concurrency of the pipeline fn concurrency(&self) -> Option { None } fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!( Transformer); #[cfg(feature = "test-utils")] mock! 
{ #[derive(Debug)] pub Transformer {} #[async_trait] impl Transformer for Transformer { type Input = String; type Output = String; async fn transform_node(&self, node: Node) -> Result>; fn concurrency(&self) -> Option; fn name(&self) -> &'static str; } impl Clone for Transformer { fn clone(&self) -> Self; } } #[async_trait] impl Transformer for Box> { type Input = I; type Output = O; async fn transform_node(&self, node: Node) -> Result> { self.as_ref().transform_node(node).await } fn concurrency(&self) -> Option { self.as_ref().concurrency() } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl Transformer for Arc> { type Input = I; type Output = O; async fn transform_node(&self, node: Node) -> Result> { self.as_ref().transform_node(node).await } fn concurrency(&self) -> Option { self.as_ref().concurrency() } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl Transformer for &dyn Transformer { type Input = I; type Output = O; async fn transform_node(&self, node: Node) -> Result> { (*self).transform_node(node).await } fn concurrency(&self) -> Option { (*self).concurrency() } } #[async_trait] /// Use a closure as a transformer // TODO: Find a way to make this work with full generics impl Transformer for F where F: Fn(Node) -> Result> + Send + Sync + Clone, { type Input = String; type Output = String; async fn transform_node(&self, node: Node) -> Result> { self(node) } } #[async_trait] /// Transforms batched single nodes into streams of nodes pub trait BatchableTransformer: Send + Sync + DynClone { type Input: Chunk; type Output: Chunk; /// Transforms a batch of nodes into a stream of nodes async fn batch_transform(&self, nodes: Vec>) -> IndexingStream; /// Overrides the default concurrency of the pipeline fn concurrency(&self) -> Option { None } fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } /// Overrides the default batch size of the pipeline fn 
batch_size(&self) -> Option<usize> {
        None
    }
}

dyn_clone::clone_trait_object!(<I: Chunk, O: Chunk> BatchableTransformer<Input = I, Output = O>);

#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub BatchableTransformer {}
    #[async_trait]
    impl BatchableTransformer for BatchableTransformer {
        type Input = String;
        type Output = String;
        async fn batch_transform(&self, nodes: Vec<Node<String>>) -> IndexingStream<String>;
        fn name(&self) -> &'static str;
        fn batch_size(&self) -> Option<usize>;
        fn concurrency(&self) -> Option<usize>;
    }
    impl Clone for BatchableTransformer {
        fn clone(&self) -> Self;
    }
}

#[async_trait]
/// Use a closure as a batchable transformer
impl<F> BatchableTransformer for F
where
    F: Fn(Vec<Node<String>>) -> IndexingStream<String> + Send + Sync + Clone,
{
    type Input = String;
    type Output = String;

    async fn batch_transform(&self, nodes: Vec<Node<String>>) -> IndexingStream<String> {
        self(nodes)
    }
}

#[async_trait]
impl<I: Chunk, O: Chunk> BatchableTransformer for Box<dyn BatchableTransformer<Input = I, Output = O>> {
    type Input = I;
    type Output = O;

    async fn batch_transform(&self, nodes: Vec<Node<I>>) -> IndexingStream<O> {
        self.as_ref().batch_transform(nodes).await
    }

    fn concurrency(&self) -> Option<usize> {
        self.as_ref().concurrency()
    }

    fn name(&self) -> &'static str {
        self.as_ref().name()
    }
}

#[async_trait]
impl<I: Chunk, O: Chunk> BatchableTransformer for Arc<dyn BatchableTransformer<Input = I, Output = O>> {
    type Input = I;
    type Output = O;

    async fn batch_transform(&self, nodes: Vec<Node<I>>) -> IndexingStream<O> {
        self.as_ref().batch_transform(nodes).await
    }

    fn concurrency(&self) -> Option<usize> {
        self.as_ref().concurrency()
    }

    fn name(&self) -> &'static str {
        self.as_ref().name()
    }
}

#[async_trait]
impl<I: Chunk, O: Chunk> BatchableTransformer for &dyn BatchableTransformer<Input = I, Output = O> {
    type Input = I;
    type Output = O;

    async fn batch_transform(&self, nodes: Vec<Node<I>>) -> IndexingStream<O> {
        (*self).batch_transform(nodes).await
    }

    fn concurrency(&self) -> Option<usize> {
        (*self).concurrency()
    }
}

/// Starting point of a stream
pub trait Loader: DynClone + Send + Sync {
    type Output: Chunk;

    fn into_stream(self) -> IndexingStream<Self::Output>;

    /// Intended for use with Box<dyn Loader>
    ///
    /// Only needed if you use trait objects (Box<dyn Loader>)
    ///
    /// # Example
    ///
    /// ```ignore
    /// fn into_stream_boxed(self: Box<Self>) ->
IndexingStream { /// self.into_stream() /// } /// ``` fn into_stream_boxed(self: Box) -> IndexingStream { unimplemented!( "Please implement into_stream_boxed for your loader, it needs to be implemented on the concrete type" ) } fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!( Loader); #[cfg(feature = "test-utils")] mock! { #[derive(Debug)] pub Loader {} #[async_trait] impl Loader for Loader { type Output = String; fn into_stream(self) -> IndexingStream; fn into_stream_boxed(self: Box) -> IndexingStream; fn name(&self) -> &'static str; } impl Clone for Loader { fn clone(&self) -> Self; } } impl Loader for Box> { type Output = O; fn into_stream(self) -> IndexingStream { Loader::into_stream_boxed(self) } fn into_stream_boxed(self: Box) -> IndexingStream { Loader::into_stream(*self) } fn name(&self) -> &'static str { self.as_ref().name() } } impl Loader for &dyn Loader { type Output = O; fn into_stream(self) -> IndexingStream { Loader::into_stream_boxed(Box::new(self)) } fn into_stream_boxed(self: Box) -> IndexingStream { Loader::into_stream(*self) } } #[async_trait] /// Turns one node into many nodes pub trait ChunkerTransformer: Send + Sync + DynClone { type Input: Chunk; type Output: Chunk; async fn transform_node(&self, node: Node) -> IndexingStream; /// Overrides the default concurrency of the pipeline fn concurrency(&self) -> Option { None } fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!( ChunkerTransformer); #[cfg(feature = "test-utils")] mock! 
{ #[derive(Debug)] pub ChunkerTransformer {} #[async_trait] impl ChunkerTransformer for ChunkerTransformer { type Input = String; type Output = String; async fn transform_node(&self, node: Node) -> IndexingStream; fn name(&self) -> &'static str; fn concurrency(&self) -> Option; } impl Clone for ChunkerTransformer { fn clone(&self) -> Self; } } #[async_trait] impl ChunkerTransformer for Box> { type Input = I; type Output = O; async fn transform_node(&self, node: Node) -> IndexingStream { self.as_ref().transform_node(node).await } fn concurrency(&self) -> Option { self.as_ref().concurrency() } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl ChunkerTransformer for Arc> { type Input = I; type Output = O; async fn transform_node(&self, node: Node) -> IndexingStream { self.as_ref().transform_node(node).await } fn concurrency(&self) -> Option { self.as_ref().concurrency() } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl ChunkerTransformer for &dyn ChunkerTransformer { type Input = I; type Output = O; async fn transform_node(&self, node: Node) -> IndexingStream { (*self).transform_node(node).await } fn concurrency(&self) -> Option { (*self).concurrency() } } #[async_trait] impl ChunkerTransformer for F where F: Fn(Node) -> IndexingStream + Send + Sync + Clone, { async fn transform_node(&self, node: Node) -> IndexingStream { self(node) } type Input = String; type Output = String; } #[async_trait] /// Caches nodes, typically by their path and hash /// Recommended to namespace on the storage /// /// For now just bool return value for easy filter pub trait NodeCache: Send + Sync + Debug + DynClone { type Input: Chunk; async fn get(&self, node: &Node) -> bool; async fn set(&self, node: &Node); /// Optionally provide a method to clear the cache async fn clear(&self) -> Result<()> { unimplemented!("Clear not implemented") } fn name(&self) -> &'static str { let name = std::any::type_name::(); 
name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!( NodeCache); #[cfg(feature = "test-utils")] mock! { #[derive(Debug)] pub NodeCache {} #[async_trait] impl NodeCache for NodeCache { type Input = String; async fn get(&self, node: &Node) -> bool; async fn set(&self, node: &Node); async fn clear(&self) -> Result<()>; fn name(&self) -> &'static str; } impl Clone for NodeCache { fn clone(&self) -> Self; } } #[async_trait] impl NodeCache for Box> { type Input = T; async fn get(&self, node: &Node) -> bool { self.as_ref().get(node).await } async fn set(&self, node: &Node) { self.as_ref().set(node).await; } async fn clear(&self) -> Result<()> { self.as_ref().clear().await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl NodeCache for Arc> { type Input = T; async fn get(&self, node: &Node) -> bool { self.as_ref().get(node).await } async fn set(&self, node: &Node) { self.as_ref().set(node).await; } async fn clear(&self) -> Result<()> { self.as_ref().clear().await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl NodeCache for &dyn NodeCache { type Input = T; async fn get(&self, node: &Node) -> bool { (*self).get(node).await } async fn set(&self, node: &Node) { (*self).set(node).await; } async fn clear(&self) -> Result<()> { (*self).clear().await } } #[async_trait] /// Embeds a list of strings and returns its embeddings. /// Assumes the strings will be moved. pub trait EmbeddingModel: Send + Sync + Debug + DynClone { async fn embed(&self, input: Vec) -> Result; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!(EmbeddingModel); #[cfg(feature = "test-utils")] mock! 
{ #[derive(Debug)] pub EmbeddingModel {} #[async_trait] impl EmbeddingModel for EmbeddingModel { async fn embed(&self, input: Vec) -> Result; fn name(&self) -> &'static str; } impl Clone for EmbeddingModel { fn clone(&self) -> Self; } } #[async_trait] impl EmbeddingModel for Box { async fn embed(&self, input: Vec) -> Result { self.as_ref().embed(input).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl EmbeddingModel for Arc { async fn embed(&self, input: Vec) -> Result { self.as_ref().embed(input).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl EmbeddingModel for &dyn EmbeddingModel { async fn embed(&self, input: Vec) -> Result { (*self).embed(input).await } } #[async_trait] /// Embeds a list of strings and returns its embeddings. /// Assumes the strings will be moved. pub trait SparseEmbeddingModel: Send + Sync + Debug + DynClone { async fn sparse_embed( &self, input: Vec, ) -> Result; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!(SparseEmbeddingModel); #[cfg(feature = "test-utils")] mock! 
{ #[derive(Debug)] pub SparseEmbeddingModel {} #[async_trait] impl SparseEmbeddingModel for SparseEmbeddingModel { async fn sparse_embed(&self, input: Vec) -> Result; fn name(&self) -> &'static str; } impl Clone for SparseEmbeddingModel { fn clone(&self) -> Self; } } #[async_trait] impl SparseEmbeddingModel for Box { async fn sparse_embed( &self, input: Vec, ) -> Result { self.as_ref().sparse_embed(input).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl SparseEmbeddingModel for Arc { async fn sparse_embed( &self, input: Vec, ) -> Result { self.as_ref().sparse_embed(input).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl SparseEmbeddingModel for &dyn SparseEmbeddingModel { async fn sparse_embed( &self, input: Vec, ) -> Result { (*self).sparse_embed(input).await } } #[async_trait] /// Given a string prompt, queries an LLM pub trait SimplePrompt: Debug + Send + Sync + DynClone { // Takes a simple prompt, prompts the llm and returns the response async fn prompt(&self, prompt: Prompt) -> Result; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!(SimplePrompt); #[cfg(feature = "test-utils")] mock! 
{ #[derive(Debug)] pub SimplePrompt {} #[async_trait] impl SimplePrompt for SimplePrompt { async fn prompt(&self, prompt: Prompt) -> Result; fn name(&self) -> &'static str; } impl Clone for SimplePrompt { fn clone(&self) -> Self; } } #[async_trait] impl SimplePrompt for Box { async fn prompt(&self, prompt: Prompt) -> Result { self.as_ref().prompt(prompt).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl SimplePrompt for Arc { async fn prompt(&self, prompt: Prompt) -> Result { self.as_ref().prompt(prompt).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl SimplePrompt for &dyn SimplePrompt { async fn prompt(&self, prompt: Prompt) -> Result { (*self).prompt(prompt).await } } #[async_trait] /// Persists nodes pub trait Persist: Debug + Send + Sync + DynClone { type Input: Chunk; type Output: Chunk; async fn setup(&self) -> Result<()>; async fn store(&self, node: Node) -> Result>; async fn batch_store(&self, nodes: Vec>) -> IndexingStream; fn batch_size(&self) -> Option { None } fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!( Persist); #[cfg(feature = "test-utils")] mock! 
{ #[derive(Debug)] pub Persist {} #[async_trait] impl Persist for Persist { type Input = String; type Output = String; async fn setup(&self) -> Result<()>; async fn store(&self, node: Node) -> Result>; async fn batch_store(&self, nodes: Vec>) -> IndexingStream; fn batch_size(&self) -> Option; fn name(&self) -> &'static str; } impl Clone for Persist { fn clone(&self) -> Self; } } #[async_trait] impl Persist for Box> { type Input = I; type Output = O; async fn setup(&self) -> Result<()> { self.as_ref().setup().await } async fn store(&self, node: Node) -> Result> { self.as_ref().store(node).await } async fn batch_store(&self, nodes: Vec>) -> IndexingStream { self.as_ref().batch_store(nodes).await } fn batch_size(&self) -> Option { self.as_ref().batch_size() } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl Persist for Arc> { type Input = I; type Output = O; async fn setup(&self) -> Result<()> { self.as_ref().setup().await } async fn store(&self, node: Node) -> Result> { self.as_ref().store(node).await } async fn batch_store(&self, nodes: Vec>) -> IndexingStream { self.as_ref().batch_store(nodes).await } fn batch_size(&self) -> Option { self.as_ref().batch_size() } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl Persist for &dyn Persist { type Input = I; type Output = O; async fn setup(&self) -> Result<()> { (*self).setup().await } async fn store(&self, node: Node) -> Result> { (*self).store(node).await } async fn batch_store(&self, nodes: Vec>) -> IndexingStream { (*self).batch_store(nodes).await } fn batch_size(&self) -> Option { (*self).batch_size() } } /// Allows for passing defaults from the pipeline to the transformer /// Required for batch transformers as at least a marker, implementation is not required pub trait WithIndexingDefaults { fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {} } /// Allows for passing defaults from the pipeline to the batch transformer /// Required 
for batch transformers as at least a marker, implementation is not required pub trait WithBatchIndexingDefaults { fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {} } impl WithIndexingDefaults for dyn Transformer {} impl WithIndexingDefaults for Box> { fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) { self.as_mut().with_indexing_defaults(indexing_defaults); } } impl WithBatchIndexingDefaults for dyn BatchableTransformer {} impl WithBatchIndexingDefaults for Box> { fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) { self.as_mut().with_indexing_defaults(indexing_defaults); } } impl WithIndexingDefaults for F where F: Fn(Node) -> Result> {} impl WithBatchIndexingDefaults for F where F: Fn(Vec>) -> IndexingStream {} #[cfg(feature = "test-utils")] impl WithIndexingDefaults for MockTransformer {} // // #[cfg(feature = "test-utils")] impl WithBatchIndexingDefaults for MockBatchableTransformer {} #[async_trait] /// Given a string prompt, queries an LLM to return structured data pub trait StructuredPrompt: Debug + Send + Sync + DynClone { async fn structured_prompt( &self, prompt: Prompt, ) -> Result; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } /// Helper trait object to call structured prompt with dynamic dispatch /// /// Internally Swiftide only implements this trait, as implementing `DynStructuredPrompt` gives /// `StructuredPrompt` for free #[async_trait] pub trait DynStructuredPrompt: Debug + Send + Sync + DynClone { async fn structured_prompt_dyn( &self, prompt: Prompt, schema: schemars::Schema, ) -> Result; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!(DynStructuredPrompt); #[async_trait] impl StructuredPrompt for C where C: DynStructuredPrompt + Debug + Send + Sync + DynClone, { async fn structured_prompt( &self, prompt: 
Prompt, ) -> Result { // Call with T = serde_json::Value let schema = schema_for!(T); let val = self.structured_prompt_dyn(prompt, schema).await?; let parsed = serde_json::from_value(val).map_err(LanguageModelError::permanent)?; Ok(parsed) } } ================================================ FILE: swiftide-core/src/lib.rs ================================================ // show feature flags in the generated documentation // https://doc.rust-lang.org/rustdoc/unstable-features.html#extensions-to-the-doc-attribute #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, doc(auto_cfg))] #![doc(html_logo_url = "https://github.com/bosun-ai/swiftide/raw/master/images/logo.png")] #![cfg_attr(coverage_nightly, feature(coverage_attribute))] pub mod agent_traits; pub mod chat_completion; pub mod indexing_decorators; mod indexing_defaults; mod indexing_stream; pub mod indexing_traits; mod node; mod query; mod query_stream; pub mod query_traits; mod search_strategies; mod stream_backoff; pub mod token_estimation; pub mod type_aliases; pub mod document; pub mod prompt; pub use type_aliases::*; mod metadata; mod query_evaluation; /// All traits are available from the root pub use crate::agent_traits::*; pub use crate::chat_completion::traits::*; pub use crate::indexing_traits::*; pub use crate::query_traits::*; pub use crate::token_estimation::*; // Decorators are available from the root pub use crate::indexing_decorators::*; pub mod indexing { pub use crate::indexing_decorators::*; pub use crate::indexing_defaults::*; pub use crate::indexing_stream::IndexingStream; pub use crate::indexing_traits::*; pub use crate::metadata::*; pub use crate::node::*; } pub mod querying { pub use crate::document::*; pub use crate::query::*; pub use crate::query_evaluation::*; pub use crate::query_stream::*; pub use crate::query_traits::*; pub mod search_strategies { pub use crate::search_strategies::*; } } /// Re-export of commonly used dependencies. 
pub mod prelude; #[cfg(feature = "test-utils")] pub mod test_utils; pub mod util; #[cfg(feature = "metrics")] pub mod metrics; /// Pipeline statistics collection for monitoring and observability pub mod statistics; ================================================ FILE: swiftide-core/src/metadata.rs ================================================ //! Metadata is a key-value store for indexation nodes //! //! Typically metadata is used to extract or generate additional information about the node //! //! Internally it uses a `BTreeMap` to store the key-value pairs, to ensure the data is sorted. use std::collections::{BTreeMap, btree_map::IntoValues}; use serde::Deserializer; use crate::util::debug_long_utf8; #[derive(Clone, Default, PartialEq, Eq)] pub struct Metadata { inner: BTreeMap, } impl std::fmt::Debug for Metadata { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_map() .entries( self.inner .iter() .map(|(k, v): (&String, &serde_json::Value)| { let fvalue = v.as_str().map_or_else( || debug_long_utf8(v.to_string(), 100), ToString::to_string, ); (k, fvalue) }), ) .finish() } } impl Metadata { pub fn iter(&self) -> impl Iterator { self.inner.iter() } pub fn insert(&mut self, key: K, value: V) where K: Into, V: Into, { self.inner.insert(key.into(), value.into()); } pub fn get(&self, key: impl AsRef) -> Option<&serde_json::Value> { self.inner.get(key.as_ref()) } pub fn into_values(self) -> IntoValues { self.inner.into_values() } pub fn keys(&self) -> impl Iterator { self.inner.keys().map(String::as_str) } pub fn values(&self) -> impl Iterator { self.inner.values() } pub fn is_empty(&self) -> bool { self.inner.is_empty() } } impl Extend<(K, V)> for Metadata where K: Into, V: Into, { fn extend>(&mut self, iter: T) { self.inner .extend(iter.into_iter().map(|(k, v)| (k.into(), v.into()))); } } impl From> for Metadata where K: Into, V: Into, { fn from(items: Vec<(K, V)>) -> Self { let inner = items .into_iter() .map(|(k, v)| (k.into(), 
v.into())) .collect(); Metadata { inner } } } impl From<(K, V)> for Metadata where K: Into, V: Into, { fn from(items: (K, V)) -> Self { let sliced: [(K, V); 1] = [items]; let inner = sliced .into_iter() .map(|(k, v)| (k.into(), v.into())) .collect(); Metadata { inner } } } impl<'a, K, V> From<&'a [(K, V)]> for Metadata where K: Into + Clone, V: Into + Clone, { fn from(items: &'a [(K, V)]) -> Self { let inner = items .iter() .cloned() .map(|(k, v)| (k.into(), v.into())) .collect(); Metadata { inner } } } impl From<[(K, V); N]> for Metadata where K: Ord + Into, V: Into, { fn from(mut arr: [(K, V); N]) -> Self { if N == 0 { return Metadata { inner: BTreeMap::new(), }; } arr.sort_by(|a, b| a.0.cmp(&b.0)); let inner: BTreeMap = arr.into_iter().map(|(k, v)| (k.into(), v.into())).collect(); Metadata { inner } } } impl IntoIterator for Metadata { type Item = (String, serde_json::Value); type IntoIter = std::collections::btree_map::IntoIter; fn into_iter(self) -> Self::IntoIter { self.inner.into_iter() } } impl<'iter> IntoIterator for &'iter Metadata { type Item = (&'iter String, &'iter serde_json::Value); type IntoIter = std::collections::btree_map::Iter<'iter, String, serde_json::Value>; fn into_iter(self) -> Self::IntoIter { self.inner.iter() } } impl<'de> serde::Deserialize<'de> for Metadata { fn deserialize>(deserializer: D) -> Result { BTreeMap::deserialize(deserializer).map(|inner| Metadata { inner }) } } impl serde::Serialize for Metadata { fn serialize(&self, serializer: S) -> Result { self.inner.serialize(serializer) } } #[cfg(test)] mod tests { use super::*; use serde_json::json; #[test] fn test_insert_and_get() { let mut metadata = Metadata::default(); let key = "key"; let value = "value"; metadata.insert(key, "value"); assert_eq!(metadata.get(key).unwrap().as_str(), Some(value)); } #[test] fn test_iter() { let mut metadata = Metadata::default(); metadata.insert("key1", json!("value1")); metadata.insert("key2", json!("value2")); let mut iter = metadata.iter(); 
// NOTE(review): this span is a single extraction-flattened line; reformatted for
// readability only — tokens are unchanged. The tail of metadata.rs's `mod tests`
// continues here (the enclosing `fn test_iter` opens on the previous line).
assert_eq!(iter.next(), Some((&"key1".to_string(), &json!("value1"))));
assert_eq!(iter.next(), Some((&"key2".to_string(), &json!("value2"))));
// BTreeMap-backed Metadata yields entries in key order, then terminates.
assert_eq!(iter.next(), None);
}

#[test]
fn test_extend() {
    let mut metadata = Metadata::default();
    metadata.extend(vec![("key1", json!("value1")), ("key2", json!("value2"))]);
    assert_eq!(metadata.get("key1"), Some(&json!("value1")));
    assert_eq!(metadata.get("key2"), Some(&json!("value2")));
}

#[test]
fn test_from_vec() {
    let metadata = Metadata::from(vec![("key1", json!("value1")), ("key2", json!("value2"))]);
    assert_eq!(metadata.get("key1"), Some(&json!("value1")));
    assert_eq!(metadata.get("key2"), Some(&json!("value2")));
}

#[test]
fn test_into_values() {
    let mut metadata = Metadata::default();
    metadata.insert("key1", json!("value1"));
    metadata.insert("key2", json!("value2"));
    // `into_values` consumes the map; values come out in key order.
    let values: Vec<_> = metadata.into_values().collect();
    assert_eq!(values, vec![json!("value1"), json!("value2")]);
}
}

================================================ FILE: swiftide-core/src/metrics.rs ================================================

use std::sync::OnceLock;

use metrics::{IntoLabels, Label, counter, describe_counter};

// One-shot init guard; the closure in `lazy_init` stores `true` as a marker.
// NOTE(review): the generic argument was stripped by extraction — upstream this is
// presumably `OnceLock<bool>`; confirm against the repository.
static METRICS_INIT: OnceLock = OnceLock::new();

/// Lazily describes all the metrics used in this module once
pub fn lazy_init() {
    METRICS_INIT.get_or_init(|| {
        describe_counter!("swiftide.usage.prompt_tokens", "token usage for the prompt");
        describe_counter!(
            "swiftide.usage.completion_tokens",
            "token usage for the completion"
        );
        describe_counter!("swiftide.usage.total_tokens", "total token usage");
        // The stored value only marks that the descriptions ran once.
        true
    });
}

/// Emits usage metrics for a language model
///
/// Counters are labelled with the model name plus any caller-supplied labels.
// NOTE(review): the type parameter of `custom_metadata` was stripped by extraction;
// the `into_labels()` call below implies it implements `IntoLabels` — confirm.
pub fn emit_usage(
    model: &str,
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
    custom_metadata: Option,
) {
    let model = model.to_string();
    let mut metadata = vec![];
    if let Some(custom_metadata) = custom_metadata {
        metadata.extend(custom_metadata.into_labels());
    }
    metadata.push(Label::new("model", model));
    lazy_init();
// Continuation of `emit_usage` (opened on the previous line): bump each usage
// counter with the same label set (model + caller-supplied labels).
counter!("swiftide.usage.prompt_tokens", metadata.iter()).increment(prompt_tokens);
counter!("swiftide.usage.completion_tokens", metadata.iter()).increment(completion_tokens);
counter!("swiftide.usage.total_tokens", metadata.iter()).increment(total_tokens);
}

================================================ FILE: swiftide-core/src/node.rs ================================================

//! This module defines the `Node` struct and its associated methods.
//!
//! `Node` represents a unit of data in the indexing process, containing metadata,
//! the data chunk itself, and an optional vector representation.
//!
//! # Overview
//!
//! The `Node` struct is designed to encapsulate all necessary information for a single
//! unit of data being processed in the indexing pipeline. It includes fields for an identifier,
//! file path, data chunk, optional vector representation, and metadata.
//!
//! The struct provides methods to convert the node into an embeddable string format and to
//! calculate a hash value for the node based on its path and chunk.
//!
//! # Usage
//!
//! The `Node` struct is used throughout the indexing pipeline to represent and process
//! individual units of data. It is particularly useful in scenarios where metadata and data chunks
//! need to be processed together.

// NOTE(review): `os::unix::ffi::OsStrExt` is unix-only; as written this module will
// not build on Windows — confirm whether that is intended for this crate.
use std::{
    collections::HashMap,
    fmt::Debug,
    hash::{Hash, Hasher},
    os::unix::ffi::OsStrExt,
    path::PathBuf,
};

use derive_builder::Builder;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::{Embedding, SparseEmbedding, metadata::Metadata};

/// Helper trait for types that can be used as data chunks in a `Node`.
/// For now always expects an owned value
///
/// A chunk must be able to yield its bytes, be cloned (not while streaming), and be sent across
/// threads.
pub trait Chunk: Clone + Send + Sync + Debug + AsRef<[u8]> + 'static {} impl Chunk for T where T: Clone + Send + Sync + Debug + AsRef<[u8]> + 'static {} /// Represents a unit of data in the indexing process. /// /// `Node` encapsulates all necessary information for a single unit of data being processed /// in the indexing pipeline. It includes fields for an identifier, file path, data chunk, optional /// vector representation, and metadata. #[derive(Default, Clone, Serialize, Deserialize, PartialEq, Builder)] #[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))] pub struct Node { /// File path associated with the node. #[builder(default)] pub path: PathBuf, /// Data chunk contained in the node. pub chunk: T, /// Optional vector representation of embedded data. #[builder(default)] pub vectors: Option>, /// Optional sparse vector representation of embedded data. #[builder(default)] pub sparse_vectors: Option>, /// Metadata associated with the node. #[builder(default)] pub metadata: Metadata, /// Mode of embedding data Chunk and Metadata #[builder(default)] pub embed_mode: EmbedMode, /// Size of the input this node was originally derived from in bytes #[builder(default)] pub original_size: usize, /// Offset of the chunk relative to the start of the input this node was originally derived /// from in bytes #[builder(default)] pub offset: usize, } pub type TextNode = Node; impl NodeBuilder { pub fn maybe_sparse_vectors( &mut self, sparse_vectors: Option>, ) -> &mut Self { self.sparse_vectors = Some(sparse_vectors); self } pub fn maybe_vectors( &mut self, vectors: Option>, ) -> &mut Self { self.vectors = Some(vectors); self } } impl Debug for Node { /// Formats the node for debugging purposes. /// /// This method is used to provide a human-readable representation of the node when debugging. /// The vector field is displayed as the number of elements in the vector if present. 
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Node") .field("id", &self.id()) .field("path", &self.path) .field("chunk", &self.chunk) .field("metadata", &self.metadata) .field( "vectors", &self .vectors .iter() .flat_map(HashMap::iter) .map(|(embed_type, vec)| format!("'{embed_type}': {}", vec.len())) .join(","), ) .field( "sparse_vectors", &self .sparse_vectors .iter() .flat_map(HashMap::iter) .map(|(embed_type, vec)| { format!( "'{embed_type}': indices({}), values({})", vec.indices.len(), vec.values.len() ) }) .join(","), ) .field("embed_mode", &self.embed_mode) .finish() } } impl Node { /// Builds a new instance of `Node`, returning a `NodeBuilder`. Copies /// over the fields from the provided `Node`. pub fn build_from_other(node: &Node) -> NodeBuilder { NodeBuilder::default() .path(node.path.clone()) .chunk(node.chunk.clone()) .metadata(node.metadata.clone()) .maybe_vectors(node.vectors.clone()) .maybe_sparse_vectors(node.sparse_vectors.clone()) .embed_mode(node.embed_mode) .original_size(node.original_size) .offset(node.offset) .to_owned() } /// Creates a new instance of `NodeBuilder.` pub fn builder() -> NodeBuilder { NodeBuilder::default() } /// Creates a new instance of `Node` with the specified data chunk. /// /// The other fields are set to their default values. pub fn new(chunk: impl Into) -> Node { let chunk = chunk.into(); let original_size = chunk.len(); Node { chunk, original_size, ..Default::default() } } pub fn with_metadata(&mut self, metadata: impl Into) -> &mut Self { self.metadata = metadata.into(); self } pub fn with_vectors( &mut self, vectors: impl Into>, ) -> &mut Self { self.vectors = Some(vectors.into()); self } pub fn with_sparse_vectors( &mut self, sparse_vectors: impl Into>, ) -> &mut Self { self.sparse_vectors = Some(sparse_vectors.into()); self } /// Retrieve the identifier of the node. /// /// Calculates the identifier of the node based on its path and chunk as bytes, returning a /// UUID (v3). 
/// (continues the doc comment for `Node::id` started on the previous line)
///
/// WARN: Does not memoize the id. Use sparingly.
pub fn id(&self) -> uuid::Uuid {
    // Calculate the identifier based on the path and chunk as bytes
    let bytes = [self.path.as_os_str().as_bytes(), self.chunk.as_ref()].concat();
    // Deterministic UUIDv3 over the OID namespace: same path + chunk => same id.
    uuid::Uuid::new_v3(&uuid::Uuid::NAMESPACE_OID, &bytes)
}
}

// NOTE(review): the type parameters of this impl block were stripped by the
// extraction (upstream presumably `impl Node<String>` or similar, since the
// methods below produce `String` embeddables) — confirm against the repository.
impl Node {
    /// Creates embeddable data depending on chosen `EmbedMode`.
    ///
    /// # Returns
    ///
    /// Embeddable data mapped to their `EmbeddedField`.
    pub fn as_embeddables(&self) -> Vec<(EmbeddedField, String)> {
        // TODO: Cow and borrow the inner data + generic
        let mut embeddables = Vec::new();
        // SingleWithMetadata / Both: one combined "metadata + chunk" embeddable.
        if self.embed_mode == EmbedMode::SingleWithMetadata || self.embed_mode == EmbedMode::Both {
            embeddables.push((EmbeddedField::Combined, self.combine_chunk_with_metadata()));
        }
        // PerField / Both: the chunk on its own plus one embeddable per metadata entry.
        if self.embed_mode == EmbedMode::PerField || self.embed_mode == EmbedMode::Both {
            embeddables.push((EmbeddedField::Chunk, self.chunk.clone()));
            for (name, value) in &self.metadata {
                // JSON strings are embedded verbatim; other JSON values via their JSON text.
                let value = value
                    .as_str()
                    .map_or_else(|| value.to_string(), ToString::to_string);
                embeddables.push((EmbeddedField::Metadata(name.clone()), value));
            }
        }
        embeddables
    }

    /// Converts the node into an [`self::EmbeddedField::Combined`] type of embeddable.
    ///
    /// This embeddable format consists of the metadata formatted as key-value pairs, each on a new
    /// line, followed by the data chunk.
    ///
    /// # Returns
    ///
    /// A string representing the embeddable format of the node.
    fn combine_chunk_with_metadata(&self) -> String {
        // Metadata formatted by newlines joined with the chunk
        let metadata = self
            .metadata
            .iter()
            .map(|(k, v)| {
                let v = v
                    .as_str()
                    .map_or_else(|| v.to_string(), ToString::to_string);
                format!("{k}: {v}")
            })
            // NOTE(review): the turbofish type was stripped by extraction
            // (presumably `::<Vec<String>>`) — confirm against the repository.
            .collect::>()
            .join("\n");
        format!("{}\n{}", metadata, self.chunk)
    }
}

// NOTE(review): generic parameters stripped by extraction (upstream presumably
// `impl<T: Chunk + Hash> Hash for Node<T>`) — confirm against the repository.
impl Hash for Node {
    /// Hashes the node based on its path and chunk.
    ///
    /// This method is used by the `calculate_hash` method to generate a hash value for the node.
// Continuation of the `Hash` impl opened on the previous line.
// NOTE(review): the generic parameter was stripped by extraction; restored here as
// `<H: Hasher>` to match the std `Hash` trait signature — confirm against upstream.
fn hash<H: Hasher>(&self, state: &mut H) {
    // Only path + chunk participate in the hash, mirroring `Node::id`;
    // vectors and metadata are deliberately excluded from identity.
    self.path.hash(state);
    self.chunk.hash(state);
}
}

/// Any string-like value can be lifted into a text node.
// NOTE(review): generics restored from context (the body fixes `value: String` and
// calls `Node::<String>::new`); the extraction had stripped them — confirm upstream.
impl<T: Into<String>> From<T> for Node<String> {
    fn from(value: T) -> Self {
        let value: String = value.into();
        Node::<String>::new(value)
    }
}

/// Embed mode of the pipeline.
#[derive(Copy, Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
pub enum EmbedMode {
    #[default]
    /// Embedding Chunk of data combined with Metadata.
    SingleWithMetadata,
    /// Embedding Chunk of data and every Metadata separately.
    PerField,
    /// Embedding Chunk of data and every Metadata separately and Chunk of data combined with
    /// Metadata.
    Both,
}

/// Type of Embeddable stored in model.
#[derive(
    Clone, Default, Serialize, Deserialize, PartialEq, Eq, Hash, strum_macros::Display, Debug,
)]
pub enum EmbeddedField {
    #[default]
    /// Embeddable created from Chunk of data combined with Metadata.
    Combined,
    /// Embeddable created from Chunk of data only.
    Chunk,
    /// Embeddable created from Metadata.
    /// String stores Metadata name.
    #[strum(to_string = "Metadata: {0}")]
    Metadata(String),
}

impl EmbeddedField {
    /// Returns the name of the field when it would be a sparse vector
    pub fn sparse_field_name(&self) -> String {
        format!("{self}_sparse")
    }

    /// Returns the name of the field when it would be a dense vector
    pub fn field_name(&self) -> String {
        // Was `format!("{self}")`; `to_string` performs the same Display round-trip
        // without a useless format (clippy::useless_format).
        self.to_string()
    }
}

#[allow(clippy::from_over_into)]
impl Into<String> for EmbeddedField {
    fn into(self) -> String {
        self.to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    #[test_case(&EmbeddedField::Combined, ["Combined", "Combined_sparse"])]
    #[test_case(&EmbeddedField::Chunk, ["Chunk", "Chunk_sparse"])]
    #[test_case(&EmbeddedField::Metadata("test".into()), ["Metadata: test", "Metadata: test_sparse"])]
    fn field_name_tests(embedded_field: &EmbeddedField, expected: [&str; 2]) {
        assert_eq!(embedded_field.field_name(), expected[0]);
        assert_eq!(embedded_field.sparse_field_name(), expected[1]);
    }

    #[test]
    fn test_debugging_node_with_utf8_char_boundary() {
        let node =
Node::from("🦀".repeat(101)); // Single char let _ = format!("{node:?}"); // With invalid char boundary let node = Node::from("Jürgen".repeat(100)); let _ = format!("{node:?}"); } #[test] fn test_build_from_other_without_vectors() { let original_node = Node::from("test_chunk") .with_metadata(Metadata::default()) .with_vectors(HashMap::new()) .with_sparse_vectors(HashMap::new()) .to_owned(); let builder = Node::build_from_other(&original_node); let new_node = builder.build().unwrap(); assert_eq!(original_node, new_node); } #[test] fn test_build_from_other_with_vectors() { let mut vectors = HashMap::new(); vectors.insert(EmbeddedField::Chunk, Embedding::default()); let mut sparse_vectors = HashMap::new(); sparse_vectors.insert( EmbeddedField::Chunk, SparseEmbedding { indices: vec![], values: vec![], }, ); let original_node = Node::from("test_chunk") .with_metadata(Metadata::default()) .with_vectors(vectors.clone()) .with_sparse_vectors(sparse_vectors.clone()) .to_owned(); let builder = Node::build_from_other(&original_node); let new_node = builder.build().unwrap(); assert_eq!(original_node, new_node); } } ================================================ FILE: swiftide-core/src/prelude.rs ================================================ pub use anyhow::{Context as _, Result}; pub use async_trait::async_trait; pub use derive_builder::Builder; pub use futures_util::{StreamExt, TryStreamExt}; pub use std::sync::Arc; pub use tracing::Instrument; #[cfg(feature = "test-utils")] pub use crate::assert_default_prompt_snapshot; ================================================ FILE: swiftide-core/src/prompt.rs ================================================ //! Prompts templating and management //! //! Prompts are first class citizens in Swiftide and use [tera] under the hood. tera //! uses jinja style templates which allows for a lot of flexibility. //! //! Conceptually, a [Prompt] is something you send to i.e. //! [`SimplePrompt`][crate::SimplePrompt]. A prompt can have //! 
added context for substitution and other templating features. //! //! Transformers in Swiftide come with default prompts, and they can be customized or replaced as //! needed. //! //! [`Template`] can be added with [`Template::try_compiled_from_str`]. Prompts can also be //! created on the fly from anything that implements [`Into`]. Compiled prompts are stored //! in an internal repository. //! //! Additionally, `Template::String` and `Template::Static` can be used to create //! templates on the fly as well. //! //! It's recommended to precompile your templates. //! //! # Example //! //! ``` //! #[tokio::main] //! # async fn main() { //! # use swiftide_core::prompt::Prompt; //! let prompt = Prompt::from("hello {{world}}").with_context_value("world", "swiftide"); //! //! assert_eq!(prompt.render().unwrap(), "hello swiftide"); //! # } //! ``` use std::{ borrow::Cow, sync::{LazyLock, RwLock}, }; use anyhow::{Context as _, Result}; use tera::Tera; use crate::node::TextNode; /// A Prompt can be used with large language models to prompt. #[derive(Clone, Debug)] pub struct Prompt { template_ref: TemplateRef, context: Option, } /// References a to be rendered template /// Either a one-off template or a tera template #[derive(Clone, Debug)] enum TemplateRef { OneOff(Cow<'static, str>), Tera(Cow<'static, str>), } pub static SWIFTIDE_TERA: LazyLock> = LazyLock::new(|| RwLock::new(Tera::default())); impl Prompt { /// Extend the swiftide repository with another Tera instance. /// /// You can use this to add your own templates, functions and partials. /// /// # Panics /// /// Panics if the `RWLock` is poisoned. /// /// # Errors /// /// Errors if the `Tera` instance cannot be extended. 
pub fn extend(other: &Tera) -> Result<()> {
    // Unwrap: poisoned-lock panic is documented in the `# Panics` section of this fn.
    let mut swiftide_tera = SWIFTIDE_TERA.write().unwrap();
    swiftide_tera.extend(other)?;
    Ok(())
}

/// Create a new prompt from a compiled template that is present in the Tera repository
// NOTE(review): the `Into` target type was stripped by extraction (presumably
// `impl Into<Cow<'static, str>>`, matching `TemplateRef::Tera`) — confirm upstream.
pub fn from_compiled_template(name: impl Into>) -> Prompt {
    Prompt {
        template_ref: TemplateRef::Tera(name.into()),
        // Context is attached lazily via the `with_*` builders below.
        context: None,
    }
}

/// Adds an `ingestion::Node` to the context of the Prompt
#[must_use]
pub fn with_node(mut self, node: &TextNode) -> Self {
    let context = self.context.get_or_insert_with(tera::Context::default);
    // Exposed to templates under the `node` key (e.g. `{{ node.chunk }}`).
    context.insert("node", &node);
    self
}

/// Adds anything that implements [Into], like `Serialize` to the Prompt
// NOTE(review): `Into`'s target type was stripped by extraction (presumably
// `tera::Context`, given the `context.extend` call) — confirm upstream.
#[must_use]
pub fn with_context(mut self, new_context: impl Into) -> Self {
    let context = self.context.get_or_insert_with(tera::Context::default);
    context.extend(new_context.into());
    self
}

/// Adds a key-value pair to the context of the Prompt
// NOTE(review): `Into`'s target type was stripped by extraction (presumably
// `tera::Value`, given `context.insert(key, &value.into())`) — confirm upstream.
#[must_use]
pub fn with_context_value(mut self, key: &str, value: impl Into) -> Self {
    let context = self.context.get_or_insert_with(tera::Context::default);
    context.insert(key, &value.into());
    self
}

/// Renders a prompt
///
/// If no context is provided, the prompt will be rendered as is.
///
/// # Errors
///
/// See `Template::render`
///
/// # Panics
///
/// Panics if the `RWLock` is poisoned.
pub fn render(&self) -> Result { if self.context.is_none() && let TemplateRef::OneOff(ref template) = self.template_ref { return Ok(template.to_string()); } let context: Cow<'_, tera::Context> = self .context .as_ref() .map_or_else(|| Cow::Owned(tera::Context::default()), Cow::Borrowed); match &self.template_ref { TemplateRef::OneOff(template) => { tera::Tera::one_off(template.as_ref(), &context, false) .context("Failed to render one-off template") } TemplateRef::Tera(template) => SWIFTIDE_TERA .read() .unwrap() .render(template.as_ref(), &context) .context("Failed to render template"), } } } impl From<&'static str> for Prompt { fn from(prompt: &'static str) -> Self { Prompt { template_ref: TemplateRef::OneOff(prompt.into()), context: None, } } } impl From for Prompt { fn from(prompt: String) -> Self { Prompt { template_ref: TemplateRef::OneOff(prompt.into()), context: None, } } } #[cfg(test)] mod test { use crate::node::Node; use super::*; #[tokio::test] async fn test_prompt() { let prompt: Prompt = "hello {{world}}".into(); let prompt = prompt.with_context_value("world", "swiftide"); assert_eq!(prompt.render().unwrap(), "hello swiftide"); } #[tokio::test] async fn test_prompt_with_node() { let prompt: Prompt = "hello {{node.chunk}}".into(); let node = Node::from("test"); let prompt = prompt.with_node(&node); assert_eq!(prompt.render().unwrap(), "hello test"); } #[tokio::test] async fn test_one_off_from_string() { let mut prompt: Prompt = "hello {{world}}".into(); prompt = prompt.with_context_value("world", "swiftide"); assert_eq!(prompt.render().unwrap(), "hello swiftide"); } #[tokio::test] async fn test_extending_with_custom_repository() { let mut custom_tera = tera::Tera::new("**/some/prompts.md").unwrap(); custom_tera .add_raw_template("hello", "hello {{world}}") .unwrap(); Prompt::extend(&custom_tera).unwrap(); let prompt = Prompt::from_compiled_template("hello").with_context_value("world", "swiftide"); assert_eq!(prompt.render().unwrap(), "hello swiftide"); 
} #[tokio::test] async fn test_coercion_to_prompt() { // str let raw: &str = "hello {{world}}"; let prompt: Prompt = raw.into(); assert_eq!( prompt .with_context_value("world", "swiftide") .render() .unwrap(), "hello swiftide" ); let prompt: Prompt = raw.to_string().into(); assert_eq!( prompt .with_context_value("world", "swiftide") .render() .unwrap(), "hello swiftide" ); } #[tokio::test] async fn test_assume_rendered_unless_context_methods_called() { let prompt = Prompt::from("hello {{world}}"); assert_eq!(prompt.render().unwrap(), "hello {{world}}"); } } ================================================ FILE: swiftide-core/src/query.rs ================================================ //! A query is the main object going through a query pipeline //! //! It acts as a statemachine, with the following transitions: //! //! `states::Pending`: No documents have been retrieved //! `states::Retrieved`: Documents have been retrieved //! `states::Answered`: The query has been answered use derive_builder::Builder; use crate::{Embedding, SparseEmbedding, document::Document, util::debug_long_utf8}; /// A query is the main object going through a query pipeline /// /// It acts as a statemachine, with the following transitions: /// /// `states::Pending`: No documents have been retrieved /// `states::Retrieved`: Documents have been retrieved /// `states::Answered`: The query has been answered #[derive(Clone, Default, Builder, PartialEq)] #[builder(setter(into))] pub struct Query { original: String, #[builder(default = "self.original.clone().unwrap_or_default()")] current: String, #[builder(default = STATE::default())] state: STATE, #[builder(default)] transformation_history: Vec, // TODO: How would this work when doing a rollup query? 
#[builder(default)] pub embedding: Option, #[builder(default)] pub sparse_embedding: Option, /// Documents the query will operate on /// /// A query can retrieve multiple times, accumulating documents #[builder(default)] pub documents: Vec, } impl std::fmt::Debug for Query { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Query") .field( "original", &debug_long_utf8(&self.original, 100).lines().take(1), ) .field( "current", &debug_long_utf8(&self.current, 100).lines().take(1), ) .field("state", &self.state) .field("num_transformations", &self.transformation_history.len()) .field("embedding", &self.embedding.is_some()) .field("num_documents", &self.documents.len()) .finish() } } impl Query { pub fn builder() -> QueryBuilder { QueryBuilder::default().clone() } /// Return the query it started with pub fn original(&self) -> &str { &self.original } /// Return the current query (or after retrieval!) pub fn current(&self) -> &str { &self.current } fn transition_to(self, new_state: NEWSTATE) -> Query { Query { state: new_state, original: self.original, current: self.current, transformation_history: self.transformation_history, embedding: self.embedding, sparse_embedding: self.sparse_embedding, documents: self.documents, } } #[allow(dead_code)] pub fn history(&self) -> &Vec { &self.transformation_history } /// Returns the current documents that will be used as context for answer generation pub fn documents(&self) -> &[Document] { &self.documents } /// Returns the current documents as mutable pub fn documents_mut(&mut self) -> &mut Vec { &mut self.documents } } impl Query { /// Add retrieved documents and transition to `states::Retrieved` pub fn retrieved_documents(mut self, documents: Vec) -> Query { self.documents.extend(documents.clone()); self.transformation_history .push(TransformationEvent::Retrieved { before: self.current.clone(), after: self.current.clone(), documents, }); let state = states::Retrieved; self.transition_to(state) } 
} impl Query { pub fn new(query: impl Into) -> Self { Self { original: query.into(), ..Default::default() } } /// Transforms the current query pub fn transformed_query(&mut self, new_query: impl Into) { let new_query = new_query.into(); self.transformation_history .push(TransformationEvent::Transformed { before: self.current.clone(), after: new_query.clone(), }); self.current = new_query; } } impl Query { pub fn new() -> Self { Self::default() } /// Transforms the current response pub fn transformed_response(&mut self, new_response: impl Into) { let new_response = new_response.into(); self.transformation_history .push(TransformationEvent::Transformed { before: self.current.clone(), after: new_response.clone(), }); self.current = new_response; } /// Transition the query to `states::Answered` #[must_use] pub fn answered(mut self, answer: impl Into) -> Query { self.current = answer.into(); let state = states::Answered; self.transition_to(state) } } impl Query { pub fn new() -> Self { Self::default() } /// Returns the answer of the query pub fn answer(&self) -> &str { &self.current } } /// Marker trait for query states pub trait QueryState: Send + Sync + Default {} /// Marker trait for query states that can still retrieve pub trait CanRetrieve: QueryState {} /// States of a query pub mod states { use super::{CanRetrieve, QueryState}; #[derive(Debug, Default, Clone, PartialEq)] /// The query is pending and has not been used pub struct Pending; #[derive(Debug, Default, Clone, PartialEq)] /// Documents have been retrieved pub struct Retrieved; #[derive(Debug, Default, Clone, PartialEq)] /// The query has been answered pub struct Answered; impl QueryState for Pending {} impl QueryState for Retrieved {} impl QueryState for Answered {} impl CanRetrieve for Pending {} impl CanRetrieve for Retrieved {} } impl> From for Query { fn from(original: T) -> Self { Self { original: original.as_ref().to_string(), current: original.as_ref().to_string(), state: states::Pending, 
..Default::default() } } } #[derive(Clone, PartialEq)] /// Records changes to a query pub enum TransformationEvent { Transformed { before: String, after: String, }, Retrieved { before: String, after: String, documents: Vec, }, } impl TransformationEvent { /// Returns true if the event is a retrieval pub fn is_retrieval(&self) -> bool { matches!(self, TransformationEvent::Retrieved { .. }) } /// Returns true if the event is a transformation pub fn is_transformation(&self) -> bool { matches!(self, TransformationEvent::Transformed { .. }) } /// Returns the query before the transformation/retrieval pub fn before(&self) -> &str { match self { TransformationEvent::Transformed { before, .. } | TransformationEvent::Retrieved { before, .. } => before, } } /// Returns the query after the transformation/retrieval pub fn after(&self) -> &str { match self { TransformationEvent::Transformed { after, .. } | TransformationEvent::Retrieved { after, .. } => after, } } /// Returns the documents retrieved, if any pub fn documents(&self) -> Option<&[Document]> { match self { TransformationEvent::Retrieved { documents, .. } => Some(documents), TransformationEvent::Transformed { .. 
} => None, } } } impl std::fmt::Debug for TransformationEvent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TransformationEvent::Transformed { before, after } => { write!( f, "Transformed: {} -> {}", &debug_long_utf8(before, 100), &debug_long_utf8(after, 100) ) } TransformationEvent::Retrieved { before, after, documents, } => { write!( f, "Retrieved: {} -> {}\nDocuments: {:?}", &debug_long_utf8(before, 100), &debug_long_utf8(after, 100), documents.len() ) } } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_query_initial_state() { let query = Query::::from("test query"); assert_eq!(query.original(), "test query"); assert_eq!(query.current(), "test query"); assert_eq!(query.history().len(), 0); } #[test] fn test_query_transformed_query() { let mut query = Query::::from("test query"); query.transformed_query("new query"); assert_eq!(query.current(), "new query"); assert_eq!(query.history().len(), 1); if let TransformationEvent::Transformed { before, after } = &query.history()[0] { assert_eq!(before, "test query"); assert_eq!(after, "new query"); } else { panic!("Unexpected event in history"); } } #[test] fn test_query_retrieved_documents() { let query = Query::::from("test query"); let documents: Vec = vec!["doc1".into(), "doc2".into()]; let query = query.retrieved_documents(documents.clone()); assert_eq!(query.documents(), &documents); assert_eq!(query.history().len(), 1); if let TransformationEvent::Retrieved { before, after, documents: retrieved_docs, } = &query.history()[0] { assert_eq!(before, "test query"); assert_eq!(after, "test query"); assert_eq!(retrieved_docs, &documents); } else { panic!("Unexpected event in history"); } } #[test] fn test_query_transformed_response() { let query = Query::::from("test query"); let documents = vec!["doc1".into(), "doc2".into()]; let mut query = query.retrieved_documents(documents.clone()); query.transformed_response("new response"); assert_eq!(query.current(), "new 
response"); assert_eq!(query.history().len(), 2); assert_eq!(query.documents(), &documents); assert_eq!(query.original, "test query"); if let TransformationEvent::Transformed { before, after } = &query.history()[1] { assert_eq!(before, "test query"); assert_eq!(after, "new response"); } else { panic!("Unexpected event in history"); } } #[test] fn test_query_answered() { let query = Query::::from("test query"); let documents = vec!["doc1".into(), "doc2".into()]; let query = query.retrieved_documents(documents); let query = query.answered("the answer"); assert_eq!(query.answer(), "the answer"); } } ================================================ FILE: swiftide-core/src/query_evaluation.rs ================================================ use crate::querying::{Query, states}; /// Wraps a query for evaluation. Used by the [`crate::query_traits::EvaluateQuery`] trait. pub enum QueryEvaluation { /// Retrieve documents RetrieveDocuments(Query), /// Answer the query AnswerQuery(Query), } impl std::fmt::Debug for QueryEvaluation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { QueryEvaluation::RetrieveDocuments(query) => { write!(f, "RetrieveDocuments({query:?})") } QueryEvaluation::AnswerQuery(query) => write!(f, "AnswerQuery({query:?})"), } } } impl From> for QueryEvaluation { fn from(val: Query) -> Self { QueryEvaluation::RetrieveDocuments(val) } } impl From> for QueryEvaluation { fn from(val: Query) -> Self { QueryEvaluation::AnswerQuery(val) } } // TODO: must be a nicer way, maybe not needed and full encapsulation is better anyway impl QueryEvaluation { pub fn retrieve_documents_query(self) -> Option> { if let QueryEvaluation::RetrieveDocuments(query) = self { Some(query) } else { None } } pub fn answer_query(self) -> Option> { if let QueryEvaluation::AnswerQuery(query) = self { Some(query) } else { None } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_from_retrieved() { let query = Query::::new(); // Assuming Query has a 
new() method let evaluation = QueryEvaluation::from(query.clone()); match evaluation { QueryEvaluation::RetrieveDocuments(q) => assert_eq!(q, query), QueryEvaluation::AnswerQuery(_) => panic!("Unexpected QueryEvaluation variant"), } } #[test] fn test_from_answered() { let query = Query::::new(); // Assuming Query has a new() method let evaluation = QueryEvaluation::from(query.clone()); match evaluation { QueryEvaluation::AnswerQuery(q) => assert_eq!(q, query), QueryEvaluation::RetrieveDocuments(_) => panic!("Unexpected QueryEvaluation variant"), } } #[test] fn test_retrieve_documents_query() { let query = Query::::new(); // Assuming Query has a new() method let evaluation = QueryEvaluation::RetrieveDocuments(query.clone()); match evaluation.retrieve_documents_query() { Some(q) => assert_eq!(q, query), None => panic!("Expected a query, got None"), } } #[test] fn test_answer_query() { let query = Query::::new(); // Assuming Query has a new() method let evaluation = QueryEvaluation::AnswerQuery(query.clone()); match evaluation.answer_query() { Some(q) => assert_eq!(q, query), None => panic!("Expected a query, got None"), } } } ================================================ FILE: swiftide-core/src/query_stream.rs ================================================ //! Internally used by a query pipeline //! //! 
Has a sender and receiver to initialize the stream use anyhow::Result; use std::pin::Pin; use tokio::sync::mpsc::Sender; use tokio_stream::wrappers::ReceiverStream; use futures_util::stream::Stream; pub use futures_util::{StreamExt, TryStreamExt}; use crate::{query::QueryState, querying::Query}; /// Internally used by a query pipeline /// /// Has a sender and receiver to initialize the stream #[pin_project::pin_project] pub struct QueryStream<'stream, STATE: 'stream + QueryState> { #[pin] pub(crate) inner: Pin>> + Send + 'stream>>, #[pin] pub sender: Option>>>, } impl<'stream, STATE: QueryState + 'stream> Default for QueryStream<'stream, STATE> { fn default() -> Self { let (sender, receiver) = tokio::sync::mpsc::channel(1000); Self { inner: ReceiverStream::new(receiver).boxed(), sender: Some(sender), } } } impl Stream for QueryStream<'_, STATE> { type Item = Result>; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); this.inner.poll_next(cx) } } impl From>> + Send>>> for QueryStream<'_, STATE> { fn from(val: Pin>> + Send>>) -> Self { QueryStream { inner: val, sender: None, } } } ================================================ FILE: swiftide-core/src/query_traits.rs ================================================ use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; use dyn_clone::DynClone; use crate::{ query::{ Query, states::{self, Retrieved}, }, querying::QueryEvaluation, }; #[cfg(feature = "test-utils")] use mockall::{mock, predicate::str}; /// Can transform queries before retrieval #[async_trait] pub trait TransformQuery: Send + Sync + DynClone { async fn transform_query( &self, query: Query, ) -> Result>; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!(TransformQuery); #[cfg(feature = "test-utils")] mock! 
{ #[derive(Debug)] pub TransformQuery {} #[async_trait] impl TransformQuery for TransformQuery { async fn transform_query( &self, query: Query, ) -> Result>; fn name(&self) -> &'static str; } impl Clone for TransformQuery { fn clone(&self) -> Self; } } #[async_trait] impl TransformQuery for F where F: Fn(Query) -> Result> + Send + Sync + Clone, { async fn transform_query( &self, query: Query, ) -> Result> { (self)(query) } } #[async_trait] impl TransformQuery for Box { async fn transform_query( &self, query: Query, ) -> Result> { self.as_ref().transform_query(query).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl TransformQuery for Arc { async fn transform_query( &self, query: Query, ) -> Result> { self.as_ref().transform_query(query).await } fn name(&self) -> &'static str { self.as_ref().name() } } /// A search strategy for the query pipeline pub trait SearchStrategy: Clone + Send + Sync + Default {} /// Can retrieve documents given a `SearchStrategy` #[async_trait] pub trait Retrieve: Send + Sync + DynClone { async fn retrieve( &self, search_strategy: &S, query: Query, ) -> Result>; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!( Retrieve); #[async_trait] impl Retrieve for Box> { async fn retrieve( &self, search_strategy: &S, query: Query, ) -> Result> { self.as_ref().retrieve(search_strategy, query).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl Retrieve for Arc> { async fn retrieve( &self, search_strategy: &S, query: Query, ) -> Result> { self.as_ref().retrieve(search_strategy, query).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl Retrieve for F where S: SearchStrategy, F: Fn(&S, Query) -> Result> + Send + Sync + Clone, { async fn retrieve( &self, search_strategy: &S, query: Query, ) -> Result> { (self)(search_strategy, query) } } /// Can transform a 
response after retrieval #[async_trait] pub trait TransformResponse: Send + Sync + DynClone { async fn transform_response(&self, query: Query) -> Result>; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!(TransformResponse); #[cfg(feature = "test-utils")] mock! { #[derive(Debug)] pub TransformResponse {} #[async_trait] impl TransformResponse for TransformResponse { async fn transform_response(&self, query: Query) -> Result>; fn name(&self) -> &'static str; } impl Clone for TransformResponse { fn clone(&self) -> Self; } } #[async_trait] impl TransformResponse for F where F: Fn(Query) -> Result> + Send + Sync + Clone, { async fn transform_response(&self, query: Query) -> Result> { (self)(query) } } #[async_trait] impl TransformResponse for Box { async fn transform_response(&self, query: Query) -> Result> { self.as_ref().transform_response(query).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl TransformResponse for Arc { async fn transform_response(&self, query: Query) -> Result> { self.as_ref().transform_response(query).await } fn name(&self) -> &'static str { self.as_ref().name() } } /// Can answer the original query #[async_trait] pub trait Answer: Send + Sync + DynClone { async fn answer(&self, query: Query) -> Result>; fn name(&self) -> &'static str { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } } dyn_clone::clone_trait_object!(Answer); #[cfg(feature = "test-utils")] mock! 
{ #[derive(Debug)] pub Answer {} #[async_trait] impl Answer for Answer { async fn answer(&self, query: Query) -> Result>; fn name(&self) -> &'static str; } impl Clone for Answer { fn clone(&self) -> Self; } } #[async_trait] impl Answer for F where F: Fn(Query) -> Result> + Send + Sync + Clone, { async fn answer(&self, query: Query) -> Result> { (self)(query) } } #[async_trait] impl Answer for Box { async fn answer(&self, query: Query) -> Result> { self.as_ref().answer(query).await } fn name(&self) -> &'static str { self.as_ref().name() } } #[async_trait] impl Answer for Arc { async fn answer(&self, query: Query) -> Result> { self.as_ref().answer(query).await } fn name(&self) -> &'static str { self.as_ref().name() } } /// Evaluates a query /// /// An evaluator needs to be able to respond to each step in the query pipeline #[async_trait] pub trait EvaluateQuery: Send + Sync + DynClone { async fn evaluate(&self, evaluation: QueryEvaluation) -> Result<()>; } dyn_clone::clone_trait_object!(EvaluateQuery); #[cfg(feature = "test-utils")] mock! { #[derive(Debug)] pub EvaluateQuery {} #[async_trait] impl EvaluateQuery for EvaluateQuery { async fn evaluate(&self, evaluation: QueryEvaluation) -> Result<()>; } impl Clone for EvaluateQuery { fn clone(&self) -> Self; } } #[async_trait] impl EvaluateQuery for Box { async fn evaluate(&self, evaluation: QueryEvaluation) -> Result<()> { self.as_ref().evaluate(evaluation).await } } #[async_trait] impl EvaluateQuery for Arc { async fn evaluate(&self, evaluation: QueryEvaluation) -> Result<()> { self.as_ref().evaluate(evaluation).await } } ================================================ FILE: swiftide-core/src/search_strategies/custom_strategy.rs ================================================ //! Implements a flexible vector search strategy framework using closure-based configuration. //! Supports both synchronous and asynchronous query generation for different retrieval backends. 
use crate::querying::{self, Query, states}; use anyhow::{Result, anyhow}; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::sync::Arc; // TODO: Should be possible to remove the static bounds and allow Q as borrowed with some fu // Function type for generating retriever-specific queries type QueryGenerator = Arc) -> Result + Send + Sync>; // Function type for async query generation type AsyncQueryGenerator = Arc< dyn Fn(&Query) -> Pin> + Send>> + Send + Sync, >; /// Implements the strategy pattern for vector similarity search, allowing retrieval backends /// to define custom query generation logic through closures. pub struct CustomStrategy { query: Option>, async_query: Option>, _marker: PhantomData, } impl querying::SearchStrategy for CustomStrategy {} impl Default for CustomStrategy { fn default() -> Self { Self { query: None, async_query: None, _marker: PhantomData, } } } impl Clone for CustomStrategy { fn clone(&self) -> Self { Self { query: self.query.clone(), async_query: self.async_query.clone(), _marker: PhantomData, } } } impl CustomStrategy { /// Creates a new strategy with a synchronous query generator. pub fn from_query( query: impl Fn(&Query) -> Result + Send + Sync + 'static, ) -> Self { Self { query: Some(Arc::new(query)), async_query: None, _marker: PhantomData, } } /// Creates a new strategy with an asynchronous query generator. pub fn from_async_query( query: impl Fn(&Query) -> F + Send + Sync + 'static, ) -> Self where F: Future> + Send + 'static, { Self { query: None, async_query: Some(Arc::new(move |q| Box::pin(query(q)))), _marker: PhantomData, } } /// Generates a query using either the sync or async generator. /// Returns error if no query generator is set. 
/// /// # Errors /// Returns an error if: /// * No query generator has been configured /// * The configured query generator fails during query generation pub async fn build_query(&self, query_node: &Query) -> Result { match (&self.query, &self.async_query) { (Some(query_fn), _) => query_fn(query_node), (_, Some(async_fn)) => async_fn(query_node).await, _ => Err(anyhow!("No query function has been set.")), } } } ================================================ FILE: swiftide-core/src/search_strategies/hybrid_search.rs ================================================ use derive_builder::Builder; use crate::{indexing::EmbeddedField, querying}; use super::{DEFAULT_TOP_K, DEFAULT_TOP_N, SearchFilter}; /// A hybrid search strategy that combines a similarity search with a /// keyword search / sparse search. /// /// Defaults to a a maximum of 10 documents and `EmbeddedField::Combined` for the field(s). #[derive(Debug, Clone, Builder)] #[builder(setter(into))] pub struct HybridSearch { /// Maximum number of documents to return #[builder(default)] top_k: u64, /// Maximum number of documents to return per query #[builder(default)] top_n: u64, /// The field to use for the dense vector #[builder(default)] dense_vector_field: EmbeddedField, /// The field to use for the sparse vector /// TODO: I.e. 
lancedb does not use sparse embeddings for hybrid search #[builder(default)] sparse_vector_field: EmbeddedField, #[builder(default)] filter: Option, } impl querying::SearchStrategy for HybridSearch {} impl Default for HybridSearch { fn default() -> Self { Self { top_k: DEFAULT_TOP_K, top_n: DEFAULT_TOP_N, dense_vector_field: EmbeddedField::Combined, sparse_vector_field: EmbeddedField::Combined, filter: None, } } } impl HybridSearch { /// Creates a new hybrid search strategy that uses the provided filter pub fn from_filter(filter: FILTER) -> Self { Self { filter: Some(filter), ..Default::default() } } pub fn with_filter( self, filter: NEWFILTER, ) -> HybridSearch { HybridSearch { top_k: self.top_k, top_n: self.top_n, dense_vector_field: self.dense_vector_field, sparse_vector_field: self.sparse_vector_field, filter: Some(filter), } } /// Sets the maximum amount of total documents retrieved pub fn with_top_k(&mut self, top_k: u64) -> &mut Self { self.top_k = top_k; self } /// Returns the maximum amount of total documents to be retrieved pub fn top_k(&self) -> u64 { self.top_k } /// Sets the maximum amount of documents to be retrieved /// per individual query pub fn with_top_n(&mut self, top_n: u64) -> &mut Self { self.top_n = top_n; self } /// Returns the maximum amount of documents per query pub fn top_n(&self) -> u64 { self.top_n } /// Sets the vector field for the dense vector /// /// Defaults to `EmbeddedField::Combined` pub fn with_dense_vector_field( &mut self, dense_vector_field: impl Into, ) -> &mut Self { self.dense_vector_field = dense_vector_field.into(); self } /// Returns the field for the dense vector pub fn dense_vector_field(&self) -> &EmbeddedField { &self.dense_vector_field } /// Sets the vector field for the sparse vector (if applicable) /// /// Defaults to `EmbeddedField::Combined` pub fn with_sparse_vector_field( &mut self, sparse_vector_field: impl Into, ) -> &mut Self { self.sparse_vector_field = sparse_vector_field.into(); self } /// Returns 
the field for the dense vector pub fn sparse_vector_field(&self) -> &EmbeddedField { &self.sparse_vector_field } pub fn filter(&self) -> Option<&FILTER> { self.filter.as_ref() } } ================================================ FILE: swiftide-core/src/search_strategies/mod.rs ================================================ //! Search strategies provide a generic way for Retrievers to implement their //! search in various ways. //! //! The strategy is also yielded to the Retriever and can contain addition configuration mod custom_strategy; mod hybrid_search; mod similarity_single_embedding; pub(crate) const DEFAULT_TOP_K: u64 = 10; pub(crate) const DEFAULT_TOP_N: u64 = 10; pub use custom_strategy::*; pub use hybrid_search::*; pub use similarity_single_embedding::*; pub trait SearchFilter: Clone + Sync + Send {} #[cfg(feature = "qdrant")] impl SearchFilter for qdrant_client::qdrant::Filter {} // When no filters are applied impl SearchFilter for () {} // Lancedb uses a string filter impl SearchFilter for String {} ================================================ FILE: swiftide-core/src/search_strategies/similarity_single_embedding.rs ================================================ use crate::querying; use super::{DEFAULT_TOP_K, SearchFilter}; /// A simple, single vector similarity search where it takes the embedding on the current query /// and returns `top_k` documents. /// /// Can optionally be used with a filter. 
#[derive(Debug, Clone)] pub struct SimilaritySingleEmbedding { /// Maximum number of documents to return top_k: u64, filter: Option, } impl querying::SearchStrategy for SimilaritySingleEmbedding {} impl Default for SimilaritySingleEmbedding { fn default() -> Self { Self { top_k: DEFAULT_TOP_K, filter: None, } } } impl SimilaritySingleEmbedding<()> { /// Set an optional filter to be used in the query pub fn into_concrete_filter(&self) -> SimilaritySingleEmbedding { SimilaritySingleEmbedding:: { top_k: self.top_k, filter: None, } } } impl SimilaritySingleEmbedding { pub fn from_filter(filter: FILTER) -> Self { Self { filter: Some(filter), ..Default::default() } } /// Set the maximum amount of documents to be returned pub fn with_top_k(&mut self, top_k: u64) -> &mut Self { self.top_k = top_k; self } /// Returns the maximum of documents to be returned pub fn top_k(&self) -> u64 { self.top_k } /// Set an optional filter to be used in the query pub fn with_filter( self, filter: NEWFILTER, ) -> SimilaritySingleEmbedding { SimilaritySingleEmbedding:: { top_k: self.top_k, filter: Some(filter), } } pub fn filter(&self) -> &Option { &self.filter } } ================================================ FILE: swiftide-core/src/statistics.rs ================================================ //! Pipeline statistics collection //! //! This module provides comprehensive monitoring and observability for pipelines, //! including node counts, token usage, and timing information. //! //! # Example //! //! ```rust,ignore //! use swiftide::indexing::Pipeline; //! //! let pipeline = Pipeline::from_loader(loader) //! .then(transformer) //! .store(storage); //! //! // Run pipeline //! pipeline.run().await?; //! //! // Get statistics //! let stats = pipeline.stats(); //! println!("Processed {} nodes in {:?}", stats.nodes_processed, stats.duration()); //! 
``` use std::{ collections::HashMap, sync::{ Mutex, MutexGuard, atomic::{AtomicU64, Ordering}, }, time::{Duration, Instant}, }; const TWO_POW_32_F64: f64 = 4_294_967_296.0; fn lock_recover(mutex: &Mutex) -> MutexGuard<'_, T> { mutex .lock() .unwrap_or_else(std::sync::PoisonError::into_inner) } fn u64_to_f64(value: u64) -> f64 { let upper = u32::try_from(value >> 32).expect("upper 32 bits always fit in u32"); let lower = u32::try_from(value & u64::from(u32::MAX)).expect("lower 32 bits always fit in u32"); f64::from(upper) * TWO_POW_32_F64 + f64::from(lower) } /// Statistics for a single model's usage #[derive(Debug, Clone, Default, PartialEq)] pub struct ModelUsage { /// Number of prompt tokens used pub prompt_tokens: u64, /// Number of completion tokens used pub completion_tokens: u64, /// Total tokens used (prompt + completion) pub total_tokens: u64, /// Number of requests made to this model pub request_count: u64, } impl ModelUsage { /// Creates a new `ModelUsage` with zero counts #[must_use] pub fn new() -> Self { Self::default() } /// Records token usage for a single request pub fn record(&mut self, prompt_tokens: u64, completion_tokens: u64) { self.prompt_tokens += prompt_tokens; self.completion_tokens += completion_tokens; self.total_tokens += prompt_tokens + completion_tokens; self.request_count += 1; } } /// A snapshot of pipeline statistics at a specific point in time /// /// This struct contains immutable statistics collected during pipeline execution. 
#[derive(Debug, Clone, Default, PartialEq)] pub struct PipelineStats { /// Total number of nodes processed pub nodes_processed: u64, /// Total number of nodes that resulted in error pub nodes_failed: u64, /// Total number of nodes persisted to storage pub nodes_stored: u64, /// Total number of transformations applied pub transformations_applied: u64, /// Token usage per model pub token_usage: HashMap, /// When the pipeline started started_at: Option, /// When the pipeline completed completed_at: Option, } impl PipelineStats { /// Creates a new empty `PipelineStats` #[must_use] pub fn new() -> Self { Self::default() } /// Returns the duration of the pipeline execution /// /// If the pipeline has not started, returns `None`. /// If the pipeline has started but not completed, returns the elapsed time since start. #[must_use] pub fn duration(&self) -> Option { match (self.started_at, self.completed_at) { (Some(start), Some(end)) => Some(end.duration_since(start)), (Some(start), None) => Some(start.elapsed()), _ => None, } } /// Calculates nodes processed per second /// /// Returns `None` if the pipeline hasn't started or if no nodes have been processed. 
#[must_use] pub fn nodes_per_second(&self) -> Option { let duration = self.duration()?; if duration.as_secs_f64() == 0.0 || self.nodes_processed == 0 { return None; } Some(u64_to_f64(self.nodes_processed) / duration.as_secs_f64()) } /// Returns the total number of tokens used across all models #[must_use] pub fn total_tokens(&self) -> u64 { self.token_usage.values().map(|u| u.total_tokens).sum() } /// Returns the total number of LLM requests made #[must_use] pub fn total_requests(&self) -> u64 { self.token_usage.values().map(|u| u.request_count).sum() } /// Returns the total prompt tokens across all models #[must_use] pub fn total_prompt_tokens(&self) -> u64 { self.token_usage.values().map(|u| u.prompt_tokens).sum() } /// Returns the total completion tokens across all models #[must_use] pub fn total_completion_tokens(&self) -> u64 { self.token_usage.values().map(|u| u.completion_tokens).sum() } } /// Thread-safe statistics collector for pipeline execution /// /// This collector uses atomic counters for lock-free updates and can be safely /// shared across multiple threads during pipeline processing. 
#[derive(Debug)] pub struct StatsCollector { nodes_processed: AtomicU64, nodes_failed: AtomicU64, nodes_stored: AtomicU64, transformations_applied: AtomicU64, token_usage: Mutex>, started_at: Mutex>, completed_at: Mutex>, } impl Default for StatsCollector { fn default() -> Self { Self::new() } } impl StatsCollector { /// Creates a new `StatsCollector` #[must_use] pub fn new() -> Self { Self { nodes_processed: AtomicU64::new(0), nodes_failed: AtomicU64::new(0), nodes_stored: AtomicU64::new(0), transformations_applied: AtomicU64::new(0), token_usage: Mutex::new(HashMap::new()), started_at: Mutex::new(None), completed_at: Mutex::new(None), } } /// Marks the pipeline as started pub fn start(&self) { let mut started = lock_recover(&self.started_at); *started = Some(Instant::now()); } /// Marks the pipeline as completed pub fn complete(&self) { let mut completed = lock_recover(&self.completed_at); *completed = Some(Instant::now()); } /// Increments the count of processed nodes pub fn increment_nodes_processed(&self, count: u64) { self.nodes_processed.fetch_add(count, Ordering::Relaxed); } /// Increments the count of failed nodes pub fn increment_nodes_failed(&self, count: u64) { self.nodes_failed.fetch_add(count, Ordering::Relaxed); } /// Increments the count of stored nodes pub fn increment_nodes_stored(&self, count: u64) { self.nodes_stored.fetch_add(count, Ordering::Relaxed); } /// Increments the count of applied transformations pub fn increment_transformations(&self, count: u64) { self.transformations_applied .fetch_add(count, Ordering::Relaxed); } /// Records token usage for a specific model /// /// This method is compatible with OpenTelemetry LLM specification. 
/// /// # Arguments /// /// * `model` - The name/identifier of the model /// * `prompt_tokens` - Number of tokens in the prompt /// * `completion_tokens` - Number of tokens in the completion pub fn record_token_usage( &self, model: impl AsRef, prompt_tokens: u64, completion_tokens: u64, ) { let mut usage = lock_recover(&self.token_usage); let model_usage = usage.entry(model.as_ref().to_string()).or_default(); model_usage.record(prompt_tokens, completion_tokens); } /// Returns a snapshot of the current statistics #[must_use] pub fn get_stats(&self) -> PipelineStats { PipelineStats { nodes_processed: self.nodes_processed.load(Ordering::Relaxed), nodes_failed: self.nodes_failed.load(Ordering::Relaxed), nodes_stored: self.nodes_stored.load(Ordering::Relaxed), transformations_applied: self.transformations_applied.load(Ordering::Relaxed), token_usage: lock_recover(&self.token_usage).clone(), started_at: *lock_recover(&self.started_at), completed_at: *lock_recover(&self.completed_at), } } } impl Clone for StatsCollector { fn clone(&self) -> Self { Self { nodes_processed: AtomicU64::new(self.nodes_processed.load(Ordering::Relaxed)), nodes_failed: AtomicU64::new(self.nodes_failed.load(Ordering::Relaxed)), nodes_stored: AtomicU64::new(self.nodes_stored.load(Ordering::Relaxed)), transformations_applied: AtomicU64::new( self.transformations_applied.load(Ordering::Relaxed), ), token_usage: Mutex::new(lock_recover(&self.token_usage).clone()), started_at: Mutex::new(*lock_recover(&self.started_at)), completed_at: Mutex::new(*lock_recover(&self.completed_at)), } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_stats_collector() { let collector = StatsCollector::new(); collector.start(); collector.increment_nodes_processed(10); collector.increment_nodes_failed(2); collector.increment_nodes_stored(8); collector.increment_transformations(15); collector.complete(); let stats = collector.get_stats(); assert_eq!(stats.nodes_processed, 10); assert_eq!(stats.nodes_failed, 2); 
        assert_eq!(stats.nodes_stored, 8);
        assert_eq!(stats.transformations_applied, 15);
        assert!(stats.duration().is_some());
        assert!(stats.nodes_per_second().is_some());
    }

    // Recording twice must accumulate both tokens and the request count.
    #[test]
    fn test_model_usage() {
        let mut usage = ModelUsage::new();
        usage.record(100, 50);
        usage.record(200, 100);
        assert_eq!(usage.prompt_tokens, 300);
        assert_eq!(usage.completion_tokens, 150);
        assert_eq!(usage.total_tokens, 450);
        assert_eq!(usage.request_count, 2);
    }

    // Usage is keyed per model; totals aggregate across all models.
    #[test]
    fn test_record_token_usage() {
        let collector = StatsCollector::new();
        collector.record_token_usage("gpt-4", 100, 50);
        collector.record_token_usage("gpt-4", 200, 100);
        collector.record_token_usage("gpt-3.5", 50, 25);
        let stats = collector.get_stats();
        assert_eq!(stats.token_usage.len(), 2);
        let gpt4_usage = stats.token_usage.get("gpt-4").unwrap();
        assert_eq!(gpt4_usage.prompt_tokens, 300);
        assert_eq!(gpt4_usage.completion_tokens, 150);
        assert_eq!(gpt4_usage.request_count, 2);
        assert_eq!(stats.total_tokens(), 525);
        assert_eq!(stats.total_requests(), 3);
    }

    // Fresh stats report zero everywhere and no duration/rate.
    #[test]
    fn test_empty_stats() {
        let stats = PipelineStats::new();
        assert_eq!(stats.nodes_processed, 0);
        assert_eq!(stats.nodes_failed, 0);
        assert_eq!(stats.total_tokens(), 0);
        assert!(stats.duration().is_none());
        assert!(stats.nodes_per_second().is_none());
    }

    // A clone is a detached snapshot: later updates to the original must not
    // bleed into it.
    #[test]
    fn test_stats_collector_clone() {
        let collector = StatsCollector::new();
        collector.increment_nodes_processed(5);
        collector.record_token_usage("model-1", 10, 5);
        let cloned = collector.clone();
        // Modify original
        collector.increment_nodes_processed(3);
        // Cloned should have original value
        let cloned_stats = cloned.get_stats();
        assert_eq!(cloned_stats.nodes_processed, 5);
        // Original should have updated value
        let original_stats = collector.get_stats();
        assert_eq!(original_stats.nodes_processed, 8);
    }

    // `duration()` measures up to "now" while the pipeline is still running.
    #[test]
    fn test_pipeline_stats_duration_while_running() {
        let collector = StatsCollector::new();
        collector.start();
        let stats = collector.get_stats();
        // Should return Some while running
        assert!(stats.duration().is_some());
assert_eq!(stats.completed_at, None); } } ================================================ FILE: swiftide-core/src/stream_backoff.rs ================================================ // Credits go to https://github.com/ihrwein/backoff/pull/50 use std::{pin::Pin, task::Poll, time::Duration}; use backoff::{backoff::Backoff, future::Sleeper}; use futures_util::{Stream, TryStream}; use pin_project::pin_project; // /// Applies a [`Backoff`] policy to a [`Stream`] // /// // /// After any [`Err`] is emitted, the stream is paused for [`Backoff::next_backoff`]. The // /// [`Backoff`] is [`reset`](`Backoff::reset`) on any [`Ok`] value. // /// // /// If [`Backoff::next_backoff`] returns [`None`] then the backing stream is given up on, and // closed. pub fn backoff( // stream: S, // backoff: B, // ) -> StreamBackoff { // StreamBackoff::new(stream, backoff, TokioSleeper) // } pub(crate) struct TokioSleeper; impl Sleeper for TokioSleeper { type Sleep = ::tokio::time::Sleep; fn sleep(&self, dur: Duration) -> Self::Sleep { ::tokio::time::sleep(dur) } } /// See [`backoff`] #[pin_project] pub struct StreamBackoff { #[pin] stream: S, backoff: B, sleeper: Sl, #[pin] state: State, } #[pin_project(project = StateProj)] enum State { BackingOff { #[pin] backoff_sleep: Sl::Sleep, }, GivenUp, Awake, } impl StreamBackoff { pub fn new(stream: S, backoff: B, sleeper: Sl) -> Self { Self { stream, backoff, sleeper, state: State::Awake, } } } impl Stream for StreamBackoff where Sl::Sleep: Future, { type Item = Result; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { let mut this = self.project(); match this.state.as_mut().project() { StateProj::BackingOff { mut backoff_sleep } => match backoff_sleep.as_mut().poll(cx) { Poll::Ready(()) => { // tracing::debug!(deadline = ?backoff_sleep.deadline(), "Backoff complete, // waking up"); this.state.set(State::Awake); } Poll::Pending => { // let deadline = backoff_sleep.deadline(); // tracing::trace!( // ?deadline, // 
                    //     remaining_duration = ?deadline.saturating_duration_since(Instant::now()),
                    //     "Still waiting for backoff sleep to complete"
                    // );
                    return Poll::Pending;
                }
            },
            StateProj::GivenUp => {
                // tracing::debug!("Backoff has given up, stream is closed");
                return Poll::Ready(None);
            }
            StateProj::Awake => {}
        }

        let next_item = this.stream.try_poll_next(cx);
        match &next_item {
            // An error arms the next backoff sleep; `None` from the policy closes the stream.
            Poll::Ready(Some(Err(_))) => {
                if let Some(backoff_duration) = this.backoff.next_backoff() {
                    let backoff_sleep = this.sleeper.sleep(backoff_duration);
                    // tracing::debug!(
                    //     deadline = ?backoff_sleep.deadline(),
                    //     duration = ?backoff_duration,
                    //     "Error received, backing off"
                    // );
                    this.state.set(State::BackingOff { backoff_sleep });
                } else {
                    // tracing::debug!("Error received, giving up");
                    this.state.set(State::GivenUp);
                }
            }
            // Any `Ok` item (or end of stream) resets the policy.
            Poll::Ready(_) => {
                // tracing::trace!("Non-error received, resetting backoff");
                this.backoff.reset();
            }
            Poll::Pending => {}
        }
        // The error itself is still forwarded to the consumer; only the *next*
        // poll is delayed.
        next_item
    }
}

// Tokio clock is required to be able to freeze time during marble tests
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{StreamExt, pin_mut, poll, stream};
    use std::{task::Poll, time::Duration};
    use tokio::{self, sync::mpsc};

    // After the error is yielded, the stream is Pending until the (frozen,
    // manually advanced) clock passes the backoff interval.
    #[tokio::test]
    async fn stream_should_back_off() {
        tokio::time::pause();
        let tick = Duration::from_secs(1);
        let rx = stream::iter([Ok(0), Ok(1), Err(2), Ok(3), Ok(4)]);
        let rx = StreamBackoff::new(rx, backoff::backoff::Constant::new(tick), TokioSleeper);
        pin_mut!(rx);
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(0))));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(1))));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(2))));
        assert_eq!(poll!(rx.next()), Poll::Pending);
        tokio::time::advance(tick * 2).await;
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(3))));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(4))));
        assert_eq!(poll!(rx.next()), Poll::Ready(None));
    }

    #[tokio::test]
    async fn backoff_time_should_update() {
        tokio::time::pause();
        let (tx, rx) = mpsc::unbounded_channel();
        let rx =
tokio_stream::wrappers::UnboundedReceiverStream::new(rx); let rx = StreamBackoff::new(rx, LinearBackoff::new(Duration::from_secs(2)), TokioSleeper); pin_mut!(rx); tx.send(Ok(0)).unwrap(); assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(0)))); tx.send(Ok(1)).unwrap(); assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(1)))); tx.send(Err(2)).unwrap(); assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(2)))); assert_eq!(poll!(rx.next()), Poll::Pending); tokio::time::advance(Duration::from_secs(3)).await; assert_eq!(poll!(rx.next()), Poll::Pending); tx.send(Err(3)).unwrap(); assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(3)))); tx.send(Ok(4)).unwrap(); assert_eq!(poll!(rx.next()), Poll::Pending); tokio::time::advance(Duration::from_secs(3)).await; assert_eq!(poll!(rx.next()), Poll::Pending); tokio::time::advance(Duration::from_secs(2)).await; assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(4)))); assert_eq!(poll!(rx.next()), Poll::Pending); drop(tx); assert_eq!(poll!(rx.next()), Poll::Ready(None)); } #[tokio::test] async fn backoff_should_close_when_requested() { assert_eq!( StreamBackoff::new( stream::iter([Ok(0), Ok(1), Err(2), Ok(3)]), backoff::backoff::Stop {}, TokioSleeper ) .collect::>() .await, vec![Ok(0), Ok(1), Err(2)] ); } /// Dynamic backoff policy that is still deterministic and testable struct LinearBackoff { interval: Duration, current_duration: Duration, } impl LinearBackoff { fn new(interval: Duration) -> Self { Self { interval, current_duration: Duration::ZERO, } } } impl Backoff for LinearBackoff { fn next_backoff(&mut self) -> Option { self.current_duration += self.interval; Some(self.current_duration) } fn reset(&mut self) { self.current_duration = Duration::ZERO; } } } ================================================ FILE: swiftide-core/src/test_utils.rs ================================================ #![allow(clippy::missing_panics_doc)] use std::fmt::Write as _; use std::sync::{Arc, Mutex}; use async_trait::async_trait; use 
crate::ChatCompletionStream; use crate::chat_completion::{ ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, errors::LanguageModelError, }; use anyhow::Result; use pretty_assertions::assert_eq; #[macro_export] macro_rules! assert_default_prompt_snapshot { ($node:expr, $($key:expr => $value:expr),*) => { #[tokio::test] async fn test_default_prompt() { let template = default_prompt(); let mut prompt = template.clone().with_node(&TextNode::new($node)); $( prompt = prompt.with_context_value($key, $value); )* insta::assert_snapshot!(prompt.render().unwrap()); } }; ($($key:expr => $value:expr),*) => { #[tokio::test] async fn test_default_prompt() { let template = default_prompt(); let mut prompt = template; $( prompt = prompt.with_context_value($key, $value); )* insta::assert_snapshot!(prompt.render().unwrap()); } }; } type Expectations = Arc< Mutex< Vec<( ChatCompletionRequest<'static>, Result, )>, >, >; #[derive(Clone)] pub struct MockChatCompletion { pub expectations: Expectations, pub received_expectations: Expectations, } impl Default for MockChatCompletion { fn default() -> Self { Self::new() } } impl MockChatCompletion { pub fn new() -> Self { Self { expectations: Arc::new(Mutex::new(Vec::new())), received_expectations: Arc::new(Mutex::new(Vec::new())), } } pub fn expect_complete( &self, request: ChatCompletionRequest<'static>, response: Result, ) { let mut mutex = self.expectations.lock().unwrap(); mutex.insert(0, (request, response)); } } #[async_trait] impl ChatCompletion for MockChatCompletion { async fn complete( &self, request: &ChatCompletionRequest<'_>, ) -> Result { let request = request.to_owned(); let (expected_request, response) = self.expectations.lock().unwrap().pop().unwrap_or_else(|| { panic!( "Received completion request, but no expectations are set\n {}", pretty_request(&request) ) }); assert_eq!( &expected_request, &request, "Unexpected request\n: {}\nRemaining expectations:\n{}", pretty_request(&request), 
pretty_expectation(&(expected_request.clone(), response)) + "---\n" + &self .expectations .lock() .unwrap() .iter() .map(pretty_expectation) .collect::>() .join("---\n") ); if let Ok(response) = response { self.received_expectations .lock() .unwrap() .push((expected_request, Ok(response.clone()))); tracing::debug!( "[MockChatCompletion] Received request:\n{}\nResponse:\n{}", pretty_request(&request), pretty_response(&response) ); Ok(response) } else { let err = response.unwrap_err(); self.received_expectations .lock() .unwrap() .push((expected_request, Err(anyhow::anyhow!(err.to_string())))); Err(LanguageModelError::PermanentError(err.into())) } } /// Fakes a stream, first it checks the expectations, then it streams the response /// instantly in small chunks async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream { let response = match self.complete(request).await { Ok(response) => response, Err(err) => return err.into(), }; let (tx, rx) = tokio::sync::mpsc::unbounded_channel::< Result, >(); tokio::spawn(async move { let mut chunk_response = ChatCompletionResponse::builder() .maybe_tool_calls(response.tool_calls.clone()) .build() .unwrap(); for chunk in response.message().unwrap().split_whitespace() { tracing::debug!("[MockChatCompletion] Sending chunk: {chunk}"); let chunk_response = chunk_response.append_message_delta(Some(chunk)).clone(); let _ = tx.send(Ok(chunk_response)); tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; } }); Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)) } } impl Drop for MockChatCompletion { fn drop(&mut self) { // We are still cloned, so do not check assertions yet if Arc::strong_count(&self.received_expectations) > 1 { return; } let Ok(expectations) = self.expectations.lock() else { return; }; let Ok(received) = self.received_expectations.lock() else { return; }; if expectations.is_empty() { let num_received = received.len(); tracing::debug!("[MockChatCompletion] 
All {num_received} expectations were met"); } else { let received = received .iter() .map(pretty_expectation) .collect::>() .join("---\n"); let pending = expectations .iter() .map(pretty_expectation) .collect::>() .join("---\n"); panic!( "[MockChatCompletion] Not all expectations were met\n received:\n{received}\n\npending:\n{pending}" ); } } } fn pretty_expectation( expectation: &( ChatCompletionRequest<'static>, Result, ), ) -> String { let mut output = String::new(); let request = &expectation.0; output.push_str("Request:\n"); output.push_str(&pretty_request(request)); output.push_str(" =>\n"); let response_result = &expectation.1; if let Ok(response) = response_result { output += &pretty_response(response); } output } fn pretty_request(request: &ChatCompletionRequest<'_>) -> String { let mut output = String::new(); for message in request.messages() { writeln!(output, " {message}").unwrap(); } output } fn pretty_response(response: &ChatCompletionResponse) -> String { let mut output = String::new(); if let Some(message) = response.message() { writeln!(output, " {message}").unwrap(); } if let Some(tool_calls) = response.tool_calls() { for tool_call in tool_calls { writeln!(output, " {tool_call}").unwrap(); } } output } ================================================ FILE: swiftide-core/src/token_estimation.rs ================================================ use std::borrow::Cow; use anyhow::Result; use async_trait::async_trait; use crate::{chat_completion::ChatMessage, prompt::Prompt}; /// Estimate the number of tokens in a given value. /// /// This trait is intentionally async so implementations can defer to remote or /// more expensive estimators without blocking. 
/// /// # Examples /// /// ```rust /// # use swiftide_core::token_estimation::{CharEstimator, EstimateTokens}; /// # use swiftide_core::chat_completion::ChatMessage; /// # #[tokio::main] /// # async fn main() -> anyhow::Result<()> { /// let estimator = CharEstimator; /// let message = ChatMessage::new_user("Hello from Swiftide!"); /// let tokens = estimator.estimate(&message).await?; /// assert!(tokens > 0); /// # Ok(()) /// # } /// ``` #[async_trait] pub trait EstimateTokens { async fn estimate(&self, value: impl Estimatable) -> Result; } /// A rough estimator when speed matters more than accuracy. /// /// Divides the number of characters by 4 as recommended by `OpenAI`. /// /// # Examples /// /// ```rust /// # use swiftide_core::token_estimation::{CharEstimator, EstimateTokens}; /// # #[tokio::main] /// # async fn main() -> anyhow::Result<()> { /// let estimator = CharEstimator; /// let tokens = estimator.estimate("Roughly four chars per token.").await?; /// assert!(tokens > 0); /// # Ok(()) /// # } /// ``` pub struct CharEstimator; #[async_trait] impl EstimateTokens for CharEstimator { async fn estimate(&self, value: impl Estimatable) -> Result { let s = value.for_estimate()?; Ok(s.iter().map(|s| s.chars().count()).sum::() / 4 + value.additional_tokens()) } } /// A value that can be estimated for the number of tokens it contains. /// /// # Errors /// /// Errors if the value cannot be presented for estimation. /// /// # Examples /// /// ```rust /// # use std::borrow::Cow; /// # use anyhow::Result; /// # use swiftide_core::token_estimation::Estimatable; /// struct Snippet { /// title: String, /// body: String, /// } /// /// impl Estimatable for Snippet { /// fn for_estimate(&self) -> Result>> { /// Ok(vec![Cow::Borrowed(&self.title), Cow::Borrowed(&self.body)]) /// } /// } /// ``` pub trait Estimatable: Send + Sync { /// A list of string slices used for estimation /// /// # Errors /// /// Some estimatable values may fail to render or prepare for estimation. 
fn for_estimate(&self) -> Result>>; /// Optionally return extra tokens that should be added to the estimate. fn additional_tokens(&self) -> usize { 0 } } impl Estimatable for &str { fn for_estimate(&self) -> Result>> { Ok(vec![Cow::Borrowed(self)]) } } impl Estimatable for String { fn for_estimate(&self) -> Result>> { Ok(vec![Cow::Borrowed(self.as_str())]) } } impl Estimatable for &Prompt { fn for_estimate(&self) -> Result>> { let rendered = self.render()?; Ok(vec![Cow::Owned(rendered)]) } } impl Estimatable for &ChatMessage { fn for_estimate(&self) -> Result>> { Ok(match self { ChatMessage::User(msg) | ChatMessage::Summary(msg) | ChatMessage::System(msg) => { vec![Cow::Borrowed(msg)] } ChatMessage::UserWithParts(parts) => parts .iter() .filter_map(|part| match part { crate::chat_completion::ChatMessageContentPart::Text { text } => { Some(Cow::Borrowed(text.as_ref())) } crate::chat_completion::ChatMessageContentPart::Image { .. } | crate::chat_completion::ChatMessageContentPart::Document { .. } | crate::chat_completion::ChatMessageContentPart::Audio { .. } | crate::chat_completion::ChatMessageContentPart::Video { .. } => None, }) .collect(), ChatMessage::Assistant(msg, vec) => { // Note that this is not super accurate. // // It's a bit verbose to avoid unnecessary allocations. Is what it is. 
let mut tool_calls = vec.as_ref().map(|vec| { vec.iter() .filter_map(|c| c.args().map(Cow::Borrowed)) .collect::>() }); if let Some(msg) = msg { if let Some(tool_calls) = tool_calls.as_mut() { let mut msg = vec![Cow::Borrowed(msg.as_ref())]; msg.append(tool_calls); msg } else { vec![Cow::Borrowed(msg)] } } else if let Some(tool_calls) = tool_calls { tool_calls } else { vec!["None".into()] } } ChatMessage::ToolOutput(_tool_call, tool_output) => { let tool_output_content = tool_output.content().unwrap_or_default(); vec![Cow::Borrowed(tool_output_content)] } ChatMessage::Reasoning(_reasoning_item) => vec![], }) } // 4 each for the role // // See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb fn additional_tokens(&self) -> usize { 4 } } impl Estimatable for &[ChatMessage] { fn for_estimate(&self) -> Result>> { let mut total = Vec::new(); for msg in *self { let mut v = msg .for_estimate()? .into_iter() .map(Cow::into_owned) .map(Into::into) .collect(); total.append(&mut v); } Ok(total) } // Apparently every reply is primed with a <|start|>assistant<|message|> fn additional_tokens(&self) -> usize { self.iter().map(|m| m.additional_tokens()).sum::() + 3 } } #[cfg(test)] mod tests { use super::*; use crate::chat_completion::ToolCall; #[tokio::test] async fn estimate_counts_characters_and_additional_tokens() { let estimator = CharEstimator; let tokens = estimator.estimate("abcd").await.unwrap(); assert_eq!(tokens, 1); } #[tokio::test] async fn estimate_prompt_renders_before_counting() { let estimator = CharEstimator; let prompt = Prompt::from("hello {{name}}").with_context_value("name", "swiftide"); let tokens = estimator.estimate(&prompt).await.unwrap(); assert_eq!(tokens, "hello swiftide".chars().count() / 4); } #[tokio::test] async fn estimate_chat_message_includes_role_tokens() { let estimator = CharEstimator; let message = ChatMessage::new_user("hello"); let tokens = estimator.estimate(&message).await.unwrap(); 
assert_eq!(tokens, "hello".chars().count() / 4 + 4); } #[tokio::test] async fn estimate_slice_adds_reply_priming_tokens() { let estimator = CharEstimator; let messages = [ ChatMessage::new_user("hello"), ChatMessage::new_system("world"), ]; let tokens = estimator.estimate(&messages[..]).await.unwrap(); let content_tokens = "helloworld".chars().count() / 4; let additional_tokens = 4 + 4 + 3; assert_eq!(tokens, content_tokens + additional_tokens); } #[tokio::test] async fn assistant_tool_calls_are_included_in_estimate() { let estimator = CharEstimator; let tool_call = ToolCall::builder() .id("tool-1") .name("search") .args("{\"q\":\"swiftide\"}") .build() .unwrap(); let message = ChatMessage::new_assistant(None::, Some(vec![tool_call])); let tokens = estimator.estimate(&message).await.unwrap(); let content_tokens = "{\"q\":\"swiftide\"}".chars().count() / 4; assert_eq!(tokens, content_tokens + 4); } #[tokio::test] async fn assistant_without_content_or_tools_uses_none_marker() { let message = ChatMessage::Assistant(None, None); let message_ref = &message; let content = message_ref.for_estimate().unwrap(); assert_eq!(content, vec![Cow::Borrowed("None")]); } } ================================================ FILE: swiftide-core/src/type_aliases.rs ================================================ #![cfg_attr(coverage_nightly, coverage(off))] use serde::{Deserialize, Serialize}; pub type Embedding = Vec; pub type Embeddings = Vec; #[derive(Serialize, Deserialize, Clone, PartialEq)] pub struct SparseEmbedding { pub indices: Vec, pub values: Vec, } pub type SparseEmbeddings = Vec; impl std::fmt::Debug for SparseEmbedding { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SparseEmbedding") .field("indices", &self.indices.len()) .field("values", &self.values.len()) .finish() } } ================================================ FILE: swiftide-core/src/util.rs ================================================ //! 
Utility functions for Swiftide /// Safely truncates a string to a maximum number of characters. /// /// Respects utf8 character boundaries. pub fn safe_truncate_utf8(s: impl AsRef, max_chars: usize) -> String { s.as_ref().chars().take(max_chars).collect() } /// Debug print a long string by truncating to n characters /// /// Enabled with the `truncate-debug` feature flag, which is enabled by default. /// /// If debugging large outputs is needed, set `swiftide_core` to `no-default-features` /// /// # Example /// /// ```ignore /// # use swiftide_core::util::debug_long_utf8; /// let s = debug_long_utf8("🦀".repeat(10), 3); /// /// assert_eq!(s, "🦀🦀🦀 (10)"); /// ``` pub fn debug_long_utf8(s: impl AsRef, max_chars: usize) -> String { if cfg!(feature = "truncate-debug") { let trunc = safe_truncate_utf8(&s, max_chars); format!("{} ({})", trunc, s.as_ref().chars().count()) } else { s.as_ref().into() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_safe_truncate_str_with_utf8_char_boundary() { let s = "🦀".repeat(101); // Single char assert_eq!(safe_truncate_utf8(&s, 100).chars().count(), 100); // With invalid char boundary let s = "Jürgen".repeat(100); assert_eq!(safe_truncate_utf8(&s, 100).chars().count(), 100); } } ================================================ FILE: swiftide-indexing/Cargo.toml ================================================ cargo-features = ["edition2024"] [package] name = "swiftide-indexing" version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true [dependencies] swiftide-core = { path = "../swiftide-core", version = "0.32" } swiftide-macros = { path = "../swiftide-macros", version = "0.32" } anyhow = { workspace = true } async-trait = { workspace = true } derive_builder = { workspace = true } futures-util = { workspace = true } tokio = { workspace = true, features = 
["full"] } tokio-stream = { workspace = true } num_cpus = { workspace = true } tracing = { workspace = true } itertools = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } indoc = { workspace = true } ignore = { workspace = true } text-splitter = { workspace = true, features = ["markdown"] } fs-err.workspace = true [dev-dependencies] swiftide-core = { path = "../swiftide-core", features = ["test-utils"] } test-log = { workspace = true } mockall = { workspace = true } insta = { workspace = true } test-case = { workspace = true } temp-dir = { workspace = true } [features] # TODO: Should not depend on integrations, transformers that use them should be in integrations instead and re-exported from root for convencience tree-sitter = [] [lints] workspace = true [package.metadata.docs.rs] all-features = true cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: swiftide-indexing/src/lib.rs ================================================ // show feature flags in the generated documentation // https://doc.rust-lang.org/rustdoc/unstable-features.html#extensions-to-the-doc-attribute #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, doc(auto_cfg))] #![doc(html_logo_url = "https://github.com/bosun-ai/swiftide/raw/master/images/logo.png")] pub mod loaders; pub mod persist; pub mod transformers; mod pipeline; pub use pipeline::Pipeline; ================================================ FILE: swiftide-indexing/src/loaders/file_loader.rs ================================================ //! 
Load files from a directory use std::{ io::Read as _, path::{Path, PathBuf}, }; use anyhow::Context as _; use ignore::{DirEntry, Walk}; use swiftide_core::{Loader, indexing::IndexingStream, indexing::TextNode}; use tracing::{Span, debug_span, instrument}; /// The `FileLoader` struct is responsible for loading files from a specified directory, filtering /// them based on their extensions, and creating a stream of these files for further processing. /// /// # Example /// /// Create a pipeline that loads the current directory and indexes all files with the ".rs" /// /// ```no_run /// # use swiftide_indexing as indexing; /// # use swiftide_indexing::loaders::FileLoader; /// indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"])); /// ``` #[derive(Clone, Debug)] pub struct FileLoader { pub(crate) root: PathBuf, pub(crate) extensions: Option>, } impl FileLoader { /// Creates a new `FileLoader` with the specified path. /// /// # Arguments /// /// - `root`: The root directory to load files from. /// /// # Returns /// /// A new instance of `FileLoader`. pub fn new(root: impl AsRef) -> Self { Self { root: root.as_ref().to_path_buf(), extensions: None, } } /// Adds extensions to the loader. /// /// # Arguments /// /// - `extensions`: A list of extensions to add without the leading dot. /// /// # Returns /// /// The `FileLoader` instance with the added extensions. #[must_use] pub fn with_extensions(mut self, extensions: &[impl AsRef]) -> Self { let existing = self.extensions.get_or_insert_default(); existing.extend(extensions.iter().map(|ext| ext.as_ref().to_string())); self } /// Lists the nodes (files) that match the specified extensions. /// /// # Returns /// /// A vector of `TextNode` representing the matching files. /// /// # Panics /// /// This method will panic if it fails to read a file's content. 
pub fn list_nodes(&self) -> Vec { self.iter().filter_map(Result::ok).collect() } /// Iterates over the files in the directory pub fn iter(&self) -> impl Iterator> + use<> { Iter::new(&self.root, self.extensions.clone()).fuse() } } /// An iterator that walks over the files in a directory and loads them. /// /// This is a private struct that is used to implement the `FileLoader` iterator. struct Iter { /// The walk instance that iterates over the files in the directory. walk: Walk, /// The extensions to include. include_extensions: Option>, /// A span that tracks the current file loader. span: Span, } impl Iterator for Iter { type Item = anyhow::Result; fn next(&mut self) -> Option { let _span = self.span.enter(); loop { // stop the iteration if there are no more entries let entry = self.walk.next()?; // propagate any errors that occur during the directory traversal let entry = match entry { Ok(entry) => entry, Err(err) => return Some(Err(err.into())), }; if let Some(node) = self.load(&entry) { return Some(node); } } } } impl Iter { /// Creates a new `Iter` instance. 
fn new(root: &Path, include_extensions: Option>) -> Self { let span = debug_span!("file_loader", root = %root.display()); tracing::debug!(parent: &span, extensions = ?include_extensions, "Loading files"); Self { walk: Walk::new(root), include_extensions, span, } } #[instrument(skip_all, fields(path = %entry.path().display()))] fn load(&self, entry: &DirEntry) -> Option> { if entry.file_type().is_some_and(|ft| !ft.is_file()) { // Skip directories and non-files return None; } if let Some(extensions) = &self.include_extensions { let Some(extension) = entry.path().extension() else { tracing::trace!("Skipping file without extension"); return None; }; let extension = extension.to_string_lossy(); if !extensions.iter().any(|ext| ext == &extension) { tracing::trace!("Skipping file with extension {extension}"); return None; } } tracing::debug!("Loading file"); match read_node(entry) { Ok(node) => { tracing::debug!(node_id = %node.id(), "Loaded file"); Some(Ok(node)) } Err(err) => { tracing::error!(error = %err, "Failed to load file"); Some(Err(err)) } } } } fn read_node(entry: &DirEntry) -> anyhow::Result { // Files might be invalid utf-8, so we need to read them as bytes and convert it lossy, as // Swiftide (currently) works internally with strings. let mut file = fs_err::File::open(entry.path()).context("Failed to open file")?; let mut buf = vec![]; file.read_to_end(&mut buf).context("Failed to read file")?; let content = String::from_utf8_lossy(&buf); let original_size = content.len(); TextNode::builder() .path(entry.path()) .chunk(content) .original_size(original_size) .build() } impl Loader for FileLoader { type Output = String; /// Converts the `FileLoader` into a stream of `TextNode`. /// /// # Returns /// /// An `IndexingStream` representing the stream of files. /// /// # Errors /// This method will return an error if it fails to read a file's content. 
fn into_stream(self) -> IndexingStream { IndexingStream::iter(self.iter()) } fn into_stream_boxed(self: Box) -> IndexingStream { self.into_stream() } } #[cfg(test)] mod test { use tokio_stream::StreamExt as _; use super::*; #[test] fn test_with_extensions() { let loader = FileLoader::new("/tmp").with_extensions(&["rs"]); assert_eq!(loader.extensions, Some(vec!["rs".to_string()])); } #[tokio::test] async fn test_ignores_invalid_utf8() { let tempdir = temp_dir::TempDir::new().unwrap(); fs_err::write(tempdir.child("invalid.txt"), [0x80, 0x80, 0x80]).unwrap(); let loader = FileLoader::new(tempdir.path()).with_extensions(&["txt"]); let result = loader.into_stream().collect::>().await; assert_eq!(result.len(), 1); let first = result.first().unwrap(); assert_eq!(first.as_ref().unwrap().chunk, "���".to_string()); } } ================================================ FILE: swiftide-indexing/src/loaders/mod.rs ================================================ //! The `loaders` module provides functionality for loading files from a specified directory. //! It includes the `FileLoader` struct which is used to filter and stream files based on their //! extensions. //! //! This module is a part of the Swiftide project, designed for asynchronous file indexing and //! processing. The `FileLoader` struct is re-exported for ease of use in other parts of the //! project. pub mod file_loader; pub use file_loader::FileLoader; ================================================ FILE: swiftide-indexing/src/persist/memory_storage.rs ================================================ use std::{ collections::HashMap, sync::{ Arc, atomic::{AtomicUsize, Ordering}, }, }; use anyhow::Result; use async_trait::async_trait; use derive_builder::Builder; use tokio::sync::RwLock; use swiftide_core::{ Persist, indexing::{Chunk, IndexingStream, Node}, }; #[derive(Debug, Default, Builder, Clone)] #[builder(pattern = "owned")] /// A simple in-memory storage implementation. 
/// /// Great for experimentation and testing. /// /// The storage will use a zero indexed, incremental counter as the key for each node if the node id /// is not set. pub struct MemoryStorage { data: Arc>>>, #[builder(default)] batch_size: Option, #[builder(default = Arc::new(AtomicUsize::new(0)))] node_count: Arc, } impl MemoryStorage { fn key(&self) -> String { self.node_count.fetch_add(1, Ordering::Relaxed).to_string() } /// Retrieve a node by its key pub async fn get(&self, key: impl AsRef) -> Option> { self.data.read().await.get(key.as_ref()).cloned() } /// Retrieve all nodes in the storage pub async fn get_all_values(&self) -> Vec> { self.data.read().await.values().cloned().collect() } /// Retrieve all nodes in the storage with their keys pub async fn get_all(&self) -> Vec<(String, Node)> { self.data .read() .await .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect() } } #[async_trait] impl Persist for MemoryStorage { type Input = T; type Output = T; async fn setup(&self) -> Result<()> { Ok(()) } /// Store a node by its id /// /// If the node does not have an id, a simple counter is used as the key. async fn store(&self, node: Node) -> Result> { self.data.write().await.insert(self.key(), node.clone()); Ok(node) } /// Store multiple nodes at once /// /// If a node does not have an id, a simple counter is used as the key. 
async fn batch_store(&self, nodes: Vec>) -> IndexingStream { let mut lock = self.data.write().await; for node in &nodes { lock.insert(self.key(), node.clone()); } IndexingStream::iter(nodes.into_iter().map(Ok)) } fn batch_size(&self) -> Option { self.batch_size } } #[cfg(test)] mod test { use super::*; use futures_util::TryStreamExt; use swiftide_core::indexing::TextNode; #[tokio::test] async fn test_memory_storage() { let storage = MemoryStorage::default(); let node = TextNode::default(); let node = storage.store(node.clone()).await.unwrap(); assert_eq!(storage.get("0").await, Some(node)); } #[tokio::test] async fn test_inserting_multiple_nodes() { let storage = MemoryStorage::default(); let node1 = TextNode::default(); let node2 = TextNode::default(); storage.store(node1.clone()).await.unwrap(); storage.store(node2.clone()).await.unwrap(); dbg!(storage.get_all().await); assert_eq!(storage.get("0").await, Some(node1)); assert_eq!(storage.get("1").await, Some(node2)); } #[tokio::test] async fn test_batch_store() { let storage = MemoryStorage::default(); let node1 = TextNode::default(); let node2 = TextNode::default(); let stream = storage .batch_store(vec![node1.clone(), node2.clone()]) .await; let result: Vec = stream.try_collect().await.unwrap(); assert_eq!(result.len(), 2); assert_eq!(result[0], node1); assert_eq!(result[1], node2); } } ================================================ FILE: swiftide-indexing/src/persist/mod.rs ================================================ //! Storage implementations for persisting data //! //! More storage implementations are available as integrations. 
mod memory_storage;
pub use memory_storage::MemoryStorage;

================================================
FILE: swiftide-indexing/src/pipeline.rs
================================================
use anyhow::Result;
use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
use swiftide_core::{
    BatchableTransformer, ChunkerTransformer, Loader, NodeCache, Persist, SimplePrompt,
    Transformer, WithBatchIndexingDefaults, WithIndexingDefaults,
    indexing::{Chunk, IndexingDefaults},
    statistics::StatsCollector,
};
use tokio::{
    sync::{Mutex, mpsc},
    task,
};
use tracing::Instrument;

use std::{pin::Pin, sync::Arc, time::Duration};

use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};

// Span helper: names the span after the operation plus the step's `name()`, so traces
// read e.g. `then.my_transformer` in OTEL backends.
macro_rules! trace_span {
    ($op:literal, $step:expr) => {
        tracing::trace_span!($op, "otel.name" = format!("{}.{}", $op, $step.name()),)
    };
    ($op:literal) => {
        tracing::trace_span!($op, "otel.name" = format!("{}", $op),)
    };
}

// Trace-level log for a single node passing through a step.
macro_rules! node_trace_log {
    ($step:expr, $node:expr, $msg:literal) => {
        tracing::trace!(
            node = ?$node,
            node_id = ?$node.id(),
            step = $step.name(),
            $msg
        )
    };
}

// Trace-level log for a batch of nodes passing through a step.
macro_rules! batch_node_trace_log {
    ($step:expr, $nodes:expr, $msg:literal) => {
        tracing::trace!(batch_size = $nodes.len(), nodes = ?$nodes, step = $step.name(), $msg)
    };
}

// Rebuilds a `Pipeline` around a new stream while carrying over every other field
// (setup fns, concurrency, defaults, batch size, stats).
macro_rules! pipeline_with_new_stream {
    ($pipeline:expr, $stream:expr) => {
        Pipeline {
            stream: $stream.into(),
            storage_setup_fns: $pipeline.storage_setup_fns.clone(),
            concurrency: $pipeline.concurrency,
            indexing_defaults: $pipeline.indexing_defaults.clone(),
            batch_size: $pipeline.batch_size,
            stats: $pipeline.stats.clone(),
        }
    };
}

/// The default batch size for batch processing.
const DEFAULT_BATCH_SIZE: usize = 256;

/// A pipeline for indexing files, adding metadata, chunking, transforming, embedding, and then
/// storing them.
///
/// The `Pipeline` struct orchestrates the entire file indexing process.
It is designed to be /// flexible and performant, allowing for various stages of data transformation and storage to be /// configured and executed asynchronously. /// /// # Fields /// /// * `stream` - The stream of `Node` items to be processed. /// * `storage` - Optional storage backend where the processed nodes will be stored. /// * `concurrency` - The level of concurrency for processing nodes. /// * `stats` - Statistics collector for monitoring pipeline execution. pub struct Pipeline { stream: IndexingStream, // storage: Vec>>, storage_setup_fns: Vec, concurrency: usize, indexing_defaults: IndexingDefaults, batch_size: usize, stats: StatsCollector, } type DynStorageSetupFn = Arc Pin> + Send>> + Send + Sync>; impl Default for Pipeline { /// Creates a default `Pipeline` with an empty stream, no storage, and a concurrency level equal /// to the number of CPUs. fn default() -> Self { Self { stream: IndexingStream::::empty(), storage_setup_fns: Vec::new(), concurrency: num_cpus::get(), indexing_defaults: IndexingDefaults::default(), batch_size: DEFAULT_BATCH_SIZE, stats: StatsCollector::new(), } } } impl Pipeline { /// Creates a `Pipeline` from a given loader. /// /// # Arguments /// /// * `loader` - A loader that implements the `Loader` trait. /// /// # Returns /// /// An instance of `Pipeline` initialized with the provided loader. pub fn from_loader(loader: impl Loader + 'static) -> Self { let stream = loader.into_stream(); Self { stream, ..Default::default() } } /// Sets the default LLM client to be used for LLM prompts for all transformers in the /// pipeline. #[must_use] pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self { self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client)); self } /// Creates a `Pipeline` from a given stream. /// /// # Arguments /// /// * `stream` - An `IndexingStream` containing the nodes to be processed. 
/// /// # Returns /// /// An instance of `Pipeline` initialized with the provided stream. pub fn from_stream(stream: impl Into>) -> Self { Self { stream: stream.into(), ..Default::default() } } /// Sets the concurrency level for the pipeline. By default the concurrency is set to the /// number of cpus. /// /// # Arguments /// /// * `concurrency` - The desired level of concurrency. /// /// # Returns /// /// An instance of `Pipeline` with the updated concurrency level. #[must_use] pub fn with_concurrency(mut self, concurrency: usize) -> Self { self.concurrency = concurrency; self } /// Sets the embed mode for the pipeline. The embed mode controls what (combination) fields of a /// [`Node`] be embedded with a vector when transforming with [`crate::transformers::Embed`] /// /// See also [`swiftide_core::indexing::EmbedMode`]. /// /// # Arguments /// /// * `embed_mode` - The desired embed mode. /// /// # Returns /// /// An instance of `Pipeline` with the updated embed mode. #[must_use] pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self { self.stream = self .stream .map_ok(move |mut node| { node.embed_mode = embed_mode; node }) .boxed() .into(); self } /// Filters out cached nodes using the provided cache. /// /// # Arguments /// /// * `cache` - A cache that implements the `NodeCache` trait. /// /// # Returns /// /// An instance of `Pipeline` with the updated stream that filters out cached nodes. #[must_use] pub fn filter_cached(mut self, cache: impl NodeCache + 'static) -> Self { let cache = Arc::new(cache); self.stream = self .stream .try_filter_map(move |node| { let cache = Arc::clone(&cache); let span = trace_span!("filter_cached", cache); async move { if cache.get(&node).await { node_trace_log!(cache, node, "node in cache, skipping"); Ok(None) } else { node_trace_log!(cache, node, "node not in cache, processing"); cache.set(&node).await; Ok(Some(node)) } } .instrument(span.or_current()) }) .boxed() .into(); self } /// Adds a transformer to the pipeline. 
/// /// Closures can also be provided as transformers. /// /// # Arguments /// /// * `transformer` - A transformer that implements the `Transformer` trait. /// /// # Returns /// /// An instance of `Pipeline` with the updated stream that applies the transformer to each node. #[must_use] pub fn then( self, mut transformer: impl Transformer + WithIndexingDefaults + 'static, ) -> Pipeline { let concurrency = transformer.concurrency().unwrap_or(self.concurrency); transformer.with_indexing_defaults(self.indexing_defaults.clone()); let transformer = Arc::new(transformer); let stream = self .stream .map_ok(move |node| { let transformer = transformer.clone(); let span = trace_span!("then", transformer); task::spawn( async move { node_trace_log!(transformer, node, "Transforming node"); transformer.transform_node(node).await } .instrument(span.or_current()), ) .err_into::() }) .try_buffer_unordered(concurrency) .map(|x| x.and_then(|x| x)); pipeline_with_new_stream!(self, stream.boxed()) } /// Adds a batch transformer to the pipeline. /// /// If the transformer has a batch size set, the batch size from the transformer is used, /// otherwise the pipeline default batch size ([`DEFAULT_BATCH_SIZE`]). /// /// # Arguments /// /// * `transformer` - A transformer that implements the `BatchableTransformer` trait. /// /// # Returns /// /// An instance of `Pipeline` with the updated stream that applies the batch transformer to each /// batch of nodes. 
#[must_use] pub fn then_in_batch( self, mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static, ) -> Pipeline { let concurrency = transformer.concurrency().unwrap_or(self.concurrency); transformer.with_indexing_defaults(self.indexing_defaults.clone()); let transformer = Arc::new(transformer); let stream = self .stream .try_chunks(transformer.batch_size().unwrap_or(self.batch_size)) .map_ok(move |nodes| { let transformer = Arc::clone(&transformer); let span = trace_span!("then_in_batch", transformer); tokio::spawn( async move { batch_node_trace_log!(transformer, nodes, "batch transforming nodes"); transformer.batch_transform(nodes).await } .instrument(span.or_current()), ) .map_err(anyhow::Error::from) }) .err_into::() .try_buffer_unordered(concurrency) // First get the streams from each future .try_flatten_unordered(None) // Then flatten the streams into a single stream .boxed(); pipeline_with_new_stream!(self, stream) } /// Adds a chunker transformer to the pipeline. /// /// # Arguments /// /// * `chunker` - A transformer that implements the `ChunkerTransformer` trait. /// /// # Returns /// /// An instance of `Pipeline` with the updated stream that applies the chunker transformer to /// each node. 
#[must_use] pub fn then_chunk( self, chunker: impl ChunkerTransformer + 'static, ) -> Pipeline { let chunker = Arc::new(chunker); let concurrency = chunker.concurrency().unwrap_or(self.concurrency); let stream = self .stream .map_ok(move |node| { let chunker = Arc::clone(&chunker); let span = trace_span!("then_chunk", chunker); tokio::spawn( async move { node_trace_log!(chunker, node, "Chunking node"); chunker.transform_node(node).await } .instrument(span.or_current()), ) .map_err(anyhow::Error::from) }) .err_into::() .try_buffer_unordered(concurrency) .try_flatten_unordered(None); pipeline_with_new_stream!(self, stream.boxed()) } /// Transforms and expands a single node into many nodes /// /// Sementacially identical to `then_chunk` and repurposes the `ChunkerTransformer` trait. /// /// The real difference is in communicating intent and the trace/span names. /// /// # Arguments /// /// * `transformer` - A transformer that implements the `ChunkerTransformer` trait. /// /// # Returns /// /// An instance of `Pipeline` with the updated stream that applies the chunker transformer to /// each node. #[must_use] pub fn then_expand( self, transformer: impl ChunkerTransformer + 'static, ) -> Pipeline { let chunker = Arc::new(transformer); let concurrency = chunker.concurrency().unwrap_or(self.concurrency); let stream = self .stream .map_ok(move |node| { let chunker = Arc::clone(&chunker); let span = trace_span!("then_expand", chunker); tokio::spawn( async move { node_trace_log!(chunker, node, "Expanding node"); chunker.transform_node(node).await } .instrument(span.or_current()), ) .map_err(anyhow::Error::from) }) .err_into::() .try_buffer_unordered(concurrency) .try_flatten_unordered(None); pipeline_with_new_stream!(self, stream.boxed()) } /// Persists indexing nodes using the provided storage backend. /// /// # Arguments /// /// * `storage` - A storage backend that implements the `Storage` trait. 
/// /// # Returns /// /// An instance of `Pipeline` with the configured storage backend. /// /// # Panics /// /// Panics if batch size turns out to be not set and batch storage is still invoked. /// Pipeline only invokes batch storing if the batch size is set, so should be alright. #[must_use] pub fn then_store_with( mut self, storage: impl Persist + 'static, ) -> Pipeline { let storage = Arc::new(storage); let storage_closure = storage.clone(); // Ensure we run the setup function only once. let completed = Arc::new(Mutex::new(false)); let setup_fn: DynStorageSetupFn = Arc::new(move || { let completed = Arc::clone(&completed); let storage_closure = Arc::clone(&storage_closure); Box::pin(async move { let mut lock = completed.lock().await; tracing::trace!(?storage_closure, "Setting up storage"); storage_closure.setup().await?; *lock = true; Ok(()) }) }); self.storage_setup_fns.push(setup_fn); // add storage to the stream instead of doing it at the end let stream = if storage.batch_size().is_some() { self.stream .try_chunks(storage.batch_size().unwrap()) .map_ok(move |nodes| { let storage = Arc::clone(&storage); let span = trace_span!("then_store_with_batched", storage); tokio::spawn( async move { batch_node_trace_log!(storage, nodes, "batch storing nodes"); storage.batch_store(nodes).await } .instrument(span.or_current()), ) .map_err(anyhow::Error::from) }) .err_into::() .try_buffer_unordered(self.concurrency) .try_flatten_unordered(None) .boxed() } else { self.stream .map_ok(move |node| { let storage = Arc::clone(&storage); let span = trace_span!("then_store_with", storage); tokio::spawn( async move { node_trace_log!(storage, node, "Storing node"); storage.store(node).await } .instrument(span.or_current()), ) .err_into::() }) .try_buffer_unordered(self.concurrency) .map(|x| x.and_then(|x| x)) .boxed() }; pipeline_with_new_stream!(self, stream) } /// Splits the stream into two streams based on a predicate. /// /// Note that this is not lazy. 
It will start consuming the stream immediately /// and send each item to the left or right stream based on the predicate. /// /// The other streams have a buffer, but should be started as soon as possible. /// The channels of the resulting streams are bounded and the parent stream will panic /// if sending fails. /// /// They can either be run concurrently, alternated between or merged back together. /// /// # Panics /// /// Panics if the receiving pipelines buffers are full or unavailable. #[must_use] pub fn split_by

(self, predicate: P) -> (Self, Self) where P: Fn(&Result>) -> bool + Send + Sync + 'static, { let predicate = Arc::new(predicate); let (left_tx, left_rx) = mpsc::channel(1000); let (right_tx, right_rx) = mpsc::channel(1000); let stream = self.stream; let span = trace_span!("split_by"); tokio::spawn( async move { stream .for_each_concurrent(self.concurrency, move |item| { let predicate = Arc::clone(&predicate); let left_tx = left_tx.clone(); let right_tx = right_tx.clone(); async move { if predicate(&item) { tracing::trace!(?item, "Sending to left stream"); left_tx .send(item) .await .expect("Failed to send to left stream"); } else { tracing::trace!(?item, "Sending to right stream"); right_tx .send(item) .await .expect("Failed to send to right stream"); } } }) .await; } .instrument(span.or_current()), ); let left_pipeline = pipeline_with_new_stream!(self, left_rx); let right_pipeline = pipeline_with_new_stream!(self, right_rx); (left_pipeline, right_pipeline) } /// Merges two streams into one /// /// This is useful for merging two streams that have been split using the `split_by` method. /// /// The full stream can then be processed using the `run` method. #[must_use] pub fn merge(self, other: Self) -> Self { let stream = tokio_stream::StreamExt::merge(self.stream, other.stream); Self { stream: stream.boxed().into(), ..self } } /// Throttles the stream of nodes, limiting the rate to 1 per duration. /// /// Useful for rate limiting the indexing pipeline. Uses `tokio_stream::StreamExt::throttle` /// internally which has a granularity of 1ms. #[must_use] pub fn throttle(mut self, duration: impl Into) -> Self { self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into()) .boxed() .into(); self } // Silently filters out errors encountered by the pipeline. // // This method filters out errors encountered by the pipeline, preventing them from bubbling up // and terminating the stream. Note that errors are not logged. 
#[must_use] pub fn filter_errors(mut self) -> Self { self.stream = self .stream .filter_map(|result| async { match result { Ok(node) => Some(Ok(node)), Err(_e) => None, } }) .boxed() .into(); self } /// Provide a closure to selectively filter nodes or errors /// /// This allows you to skip specific errors or nodes, or do ad hoc inspection. /// /// If the closure returns true, the result is kept, otherwise it is skipped. #[must_use] pub fn filter(mut self, filter: F) -> Self where F: Fn(&Result>) -> bool + Send + Sync + 'static, { self.stream = self .stream .filter(move |result| { let will_retain = filter(result); async move { will_retain } }) .boxed() .into(); self } /// Logs all results processed by the pipeline. /// /// This method logs all results processed by the pipeline at the `DEBUG` level. #[must_use] pub fn log_all(self) -> Self { self.log_errors().log_nodes() } /// Returns a snapshot of the current pipeline statistics /// /// This method provides real-time access to pipeline statistics during and after /// execution. The returned statistics include node counts, token usage, and timing /// information. /// /// # Example /// /// ```rust,ignore /// let pipeline = Pipeline::from_loader(loader).then(transformer); /// /// // During or after execution /// let stats = pipeline.stats(); /// println!("Processed {} nodes", stats.nodes_processed); /// ``` #[must_use] pub fn stats(&self) -> swiftide_core::statistics::PipelineStats { self.stats.get_stats() } /// Returns a reference to the statistics collector /// /// This provides direct access to the `StatsCollector` for recording additional /// metrics or for use by transformers that need to report their own statistics. #[must_use] pub fn stats_collector(&self) -> &StatsCollector { &self.stats } /// Logs all errors encountered by the pipeline. /// /// This method logs all errors encountered by the pipeline at the `ERROR` level. 
#[must_use] pub fn log_errors(mut self) -> Self { self.stream = self .stream .inspect_err(|e| tracing::error!(?e, "Error processing node")) .boxed() .into(); self } /// Logs all nodes processed by the pipeline. /// /// This method logs all nodes processed by the pipeline at the `DEBUG` level. #[must_use] pub fn log_nodes(mut self) -> Self { self.stream = self .stream .inspect_ok(|node| tracing::debug!(?node, "Processed node: {:?}", node)) .boxed() .into(); self } /// Runs the indexing pipeline. /// /// This method processes the stream of nodes, applying all configured transformations and /// storing the results. /// /// # Returns /// /// A `Result` indicating the success or failure of the pipeline execution. /// /// # Errors /// /// Returns an error if no storage backend is configured or if any stage of the pipeline fails. #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")] pub async fn run(mut self) -> Result<()> { self.stats.start(); tracing::info!( "Starting indexing pipeline with {} concurrency", self.concurrency ); // Ensure all storage backends are set up before processing nodes let setup_futures = self .storage_setup_fns .into_iter() .map(|func| async move { func().await }) .collect::>(); futures_util::future::try_join_all(setup_futures).await?; let mut total_nodes = 0u64; while let Some(_result) = self.stream.try_next().await? 
{ total_nodes += 1; // Count successful nodes as stored (nodes that reach the end of the stream) self.stats.increment_nodes_stored(1); } self.stats.increment_nodes_processed(total_nodes); self.stats.complete(); let stats = self.stats.get_stats(); let elapsed = stats.duration(); if let Some(duration) = elapsed { let elapsed_secs = duration.as_secs_f64(); let nodes_per_sec = stats.nodes_per_second().unwrap_or(0.0); tracing::info!( nodes_processed = total_nodes, nodes_stored = stats.nodes_stored, total_tokens = stats.total_tokens(), total_requests = stats.total_requests(), elapsed_secs, nodes_per_sec, "Pipeline completed" ); } tracing::Span::current().record("total_nodes", total_nodes); Ok(()) } } #[cfg(test)] mod tests { use super::*; use crate::persist::MemoryStorage; use mockall::Sequence; use swiftide_core::indexing::*; /// Tests a simple run of the indexing pipeline. #[test_log::test(tokio::test)] async fn test_simple_run() { let mut loader = MockLoader::new(); let mut transformer = MockTransformer::new(); let mut batch_transformer = MockBatchableTransformer::new(); let mut chunker = MockChunkerTransformer::new(); let mut storage = MockPersist::new(); let mut seq = Sequence::new(); loader .expect_into_stream() .times(1) .in_sequence(&mut seq) .returning(|| vec![Ok(Node::default())].into()); transformer.expect_transform_node().returning(|mut node| { node.chunk = "transformed".to_string(); Ok(node) }); transformer.expect_concurrency().returning(|| None); transformer.expect_name().returning(|| "transformer"); batch_transformer .expect_batch_transform() .times(1) .in_sequence(&mut seq) .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok))); batch_transformer.expect_concurrency().returning(|| None); batch_transformer.expect_name().returning(|| "transformer"); batch_transformer.expect_batch_size().returning(|| None); chunker .expect_transform_node() .times(1) .in_sequence(&mut seq) .returning(|node| { let mut nodes = vec![]; for i in 0..3 { let mut node = 
node.clone(); node.chunk = format!("transformed_chunk_{i}"); nodes.push(Ok(node)); } nodes.into() }); chunker.expect_concurrency().returning(|| None); chunker.expect_name().returning(|| "chunker"); storage.expect_setup().returning(|| Ok(())); storage.expect_batch_size().returning(|| None); storage .expect_store() .times(3) .in_sequence(&mut seq) .withf(|node| node.chunk.starts_with("transformed_chunk_")) .returning(Ok); storage.expect_name().returning(|| "storage"); let pipeline = Pipeline::from_loader(loader) .then(transformer) .then_in_batch(batch_transformer) .then_chunk(chunker) .then_store_with(storage); pipeline.run().await.unwrap(); } #[tokio::test] async fn test_skipping_errors() { let mut loader = MockLoader::new(); let mut transformer = MockTransformer::new(); let mut storage = MockPersist::new(); let mut seq = Sequence::new(); loader .expect_into_stream() .times(1) .in_sequence(&mut seq) .returning(|| vec![Ok(Node::default())].into()); transformer .expect_transform_node() .returning(|_node| Err(anyhow::anyhow!("Error transforming node"))); transformer.expect_concurrency().returning(|| None); transformer.expect_name().returning(|| "mock"); storage.expect_setup().returning(|| Ok(())); storage.expect_batch_size().returning(|| None); storage.expect_store().times(0).returning(Ok); let pipeline = Pipeline::from_loader(loader) .then(transformer) .then_store_with(storage) .filter_errors(); pipeline.run().await.unwrap(); } #[tokio::test] async fn test_concurrent_calls_with_simple_transformer() { let mut loader = MockLoader::new(); let mut transformer = MockTransformer::new(); let mut storage = MockPersist::new(); let mut seq = Sequence::new(); loader .expect_into_stream() .times(1) .in_sequence(&mut seq) .returning(|| { vec![ Ok(Node::default()), Ok(Node::default()), Ok(Node::default()), ] .into() }); transformer .expect_transform_node() .times(3) .in_sequence(&mut seq) .returning(|mut node| { node.chunk = "transformed".to_string(); Ok(node) }); 
transformer.expect_concurrency().returning(|| Some(3)); transformer.expect_name().returning(|| "transformer"); storage.expect_setup().returning(|| Ok(())); storage.expect_batch_size().returning(|| None); storage.expect_store().times(3).returning(Ok); storage.expect_name().returning(|| "storage"); let pipeline = Pipeline::from_loader(loader) .then(transformer) .then_store_with(storage); pipeline.run().await.unwrap(); } #[tokio::test] async fn test_arbitrary_closures_as_transformer() { let mut loader = MockLoader::new(); let transformer = |node: TextNode| { let mut node = node; node.chunk = "transformed".to_string(); Ok(node) }; let storage = MemoryStorage::default(); let mut seq = Sequence::new(); loader .expect_into_stream() .times(1) .in_sequence(&mut seq) .returning(|| vec![Ok(TextNode::default())].into()); let pipeline = Pipeline::from_loader(loader) .then(transformer) .then_store_with(storage.clone()); pipeline.run().await.unwrap(); dbg!(storage.clone()); let processed_node = storage.get("0").await.unwrap(); assert_eq!(processed_node.chunk, "transformed"); } #[tokio::test] async fn test_arbitrary_closures_as_batch_transformer() { let mut loader = MockLoader::new(); let batch_transformer = |nodes: Vec| { IndexingStream::iter(nodes.into_iter().map(|mut node| { node.chunk = "transformed".to_string(); Ok(node) })) }; let storage = MemoryStorage::default(); let mut seq = Sequence::new(); loader .expect_into_stream() .times(1) .in_sequence(&mut seq) .returning(|| vec![Ok(TextNode::default())].into()); let pipeline = Pipeline::from_loader(loader) .then_in_batch(batch_transformer) .then_store_with(storage.clone()); pipeline.run().await.unwrap(); dbg!(storage.clone()); let processed_node = storage.get("0").await.unwrap(); assert_eq!(processed_node.chunk, "transformed"); } #[tokio::test] async fn test_filter_closure() { let mut loader = MockLoader::new(); let storage = MemoryStorage::default(); let mut seq = Sequence::new(); loader .expect_into_stream() .times(1) 
.in_sequence(&mut seq) .returning(|| { vec![ Ok(TextNode::default()), Ok(TextNode::new("skip")), Ok(TextNode::default()), ] .into() }); let pipeline = Pipeline::from_loader(loader) .filter(|result| { let node = result.as_ref().unwrap(); node.chunk != "skip" }) .then_store_with(storage.clone()); pipeline.run().await.unwrap(); let nodes = storage.get_all().await; assert_eq!(nodes.len(), 2); } #[test_log::test(tokio::test)] async fn test_split_and_merge() { let mut loader = MockLoader::new(); let storage = MemoryStorage::default(); let mut seq = Sequence::new(); loader .expect_into_stream() .times(1) .in_sequence(&mut seq) .returning(|| { vec![ Ok(TextNode::default()), Ok(TextNode::new("will go left")), Ok(TextNode::default()), ] .into() }); let pipeline = Pipeline::from_loader(loader); let (mut left, mut right) = pipeline.split_by(|node| { if let Ok(node) = node { node.chunk.starts_with("will go left") } else { false } }); // change the chunk to 'left' left = left .then(move |mut node: TextNode| { node.chunk = "left".to_string(); Ok(node) }) .log_all(); right = right.then(move |mut node: TextNode| { node.chunk = "right".to_string(); Ok(node) }); left.merge(right) .then_store_with(storage.clone()) .run() .await .unwrap(); dbg!(storage.clone()); let all_nodes = storage.get_all_values().await; assert_eq!( all_nodes.iter().filter(|node| node.chunk == "left").count(), 1 ); assert_eq!( all_nodes .iter() .filter(|node| node.chunk == "right") .count(), 2 ); } #[tokio::test] async fn test_all_steps_should_work_as_dyn_box() { let mut loader = MockLoader::new(); loader .expect_into_stream_boxed() .returning(|| vec![Ok(TextNode::default())].into()); let mut transformer = MockTransformer::new(); transformer.expect_transform_node().returning(Ok); transformer.expect_concurrency().returning(|| None); transformer.expect_name().returning(|| "mock"); let mut batch_transformer = MockBatchableTransformer::new(); batch_transformer .expect_batch_transform() 
.returning(std::convert::Into::into); batch_transformer.expect_concurrency().returning(|| None); batch_transformer.expect_name().returning(|| "mock"); let mut chunker = MockChunkerTransformer::new(); chunker .expect_transform_node() .returning(|node| vec![node].into()); chunker.expect_concurrency().returning(|| None); chunker.expect_name().returning(|| "mock"); let mut storage = MockPersist::new(); storage.expect_setup().returning(|| Ok(())); storage.expect_store().returning(Ok); storage.expect_batch_size().returning(|| None); storage.expect_name().returning(|| "mock"); let pipeline = Pipeline::from_loader(Box::new(loader) as Box>) .then(Box::new(transformer) as Box>) .then_in_batch(Box::new(batch_transformer) as Box>) .then_chunk(Box::new(chunker) as Box>) .then_store_with(Box::new(storage) as Box>); pipeline.run().await.unwrap(); } #[test_log::test(tokio::test)] async fn test_pipeline_statistics() { let mut loader = MockLoader::new(); let mut storage = MockPersist::new(); let mut seq = Sequence::new(); loader .expect_into_stream() .times(1) .in_sequence(&mut seq) .returning(|| { vec![ Ok(TextNode::default()), Ok(TextNode::default()), Ok(TextNode::default()), ] .into() }); storage.expect_setup().returning(|| Ok(())); storage.expect_batch_size().returning(|| None); storage.expect_store().times(3).returning(Ok); storage.expect_name().returning(|| "storage"); let pipeline = Pipeline::from_loader(loader).then_store_with(storage); // Test that we can access stats before running let initial_stats = pipeline.stats(); assert_eq!(initial_stats.nodes_processed, 0); assert_eq!(initial_stats.nodes_stored, 0); pipeline.run().await.unwrap(); // After running, stats should be updated (access via the moved pipeline would not work, // but we verify the internal behavior through the run method's logging) } #[test_log::test(tokio::test)] async fn test_stats_collector_access() { let mut loader = MockLoader::new(); let storage = MemoryStorage::default(); loader .expect_into_stream() 
.returning(|| vec![Ok(TextNode::default()), Ok(TextNode::default())].into()); let pipeline = Pipeline::from_loader(loader).then_store_with(storage.clone()); // Access the stats collector let collector = pipeline.stats_collector(); // Record some token usage manually (simulating what transformers would do) collector.record_token_usage("gpt-4", 100, 50); collector.record_token_usage("gpt-3.5", 50, 25); let stats = collector.get_stats(); assert_eq!(stats.total_requests(), 2); assert_eq!(stats.total_tokens(), 225); // Run the pipeline pipeline.run().await.unwrap(); // Verify storage has the nodes let nodes = storage.get_all().await; assert_eq!(nodes.len(), 2); } } ================================================ FILE: swiftide-indexing/src/transformers/chunk_markdown.rs ================================================ //! Chunk markdown content into smaller pieces use std::sync::Arc; use async_trait::async_trait; use derive_builder::Builder; use swiftide_core::{ChunkerTransformer, indexing::IndexingStream, indexing::TextNode}; use text_splitter::{Characters, ChunkConfig, MarkdownSplitter}; const DEFAULT_MAX_CHAR_SIZE: usize = 2056; #[derive(Clone, Builder)] #[builder(setter(strip_option))] /// A transformer that chunks markdown content into smaller pieces. /// /// The transformer will split the markdown content into smaller pieces based on the specified /// `max_characters` or `range` of characters. /// /// For further customization, you can use the builder to create a custom splitter. /// /// Technically that might work with every splitter `text_splitter` provides. pub struct ChunkMarkdown { /// Defaults to `None`. If you use a splitter that is resource heavy, this parameter can be /// tuned. #[builder(default)] concurrency: Option, /// Optional maximum number of characters per chunk. /// /// Defaults to [`DEFAULT_MAX_CHAR_SIZE`]. #[builder(default = "DEFAULT_MAX_CHAR_SIZE")] max_characters: usize, /// A range of minimum and maximum characters per chunk. 
/// /// Chunks smaller than the range min will be ignored. `max_characters` will be ignored if this /// is set. /// /// If you provide a custom chunker with a range, you might want to set the range as well. /// /// Defaults to 0..[`max_characters`] #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")] range: std::ops::Range, /// The markdown splitter from [`text_splitter`] /// /// Defaults to a new [`MarkdownSplitter`] with the specified `max_characters`. #[builder(setter(into), default = "self.default_client()")] chunker: Arc>, } impl std::fmt::Debug for ChunkMarkdown { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ChunkMarkdown") .field("concurrency", &self.concurrency) .field("max_characters", &self.max_characters) .field("range", &self.range) .finish() } } impl Default for ChunkMarkdown { fn default() -> Self { Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE) } } impl ChunkMarkdown { pub fn builder() -> ChunkMarkdownBuilder { ChunkMarkdownBuilder::default() } /// Create a new transformer with a maximum number of characters per chunk. #[allow(clippy::missing_panics_doc)] pub fn from_max_characters(max_characters: usize) -> Self { Self::builder() .max_characters(max_characters) .build() .expect("Cannot fail") } /// Create a new transformer with a range of characters per chunk. /// /// Chunks smaller than the range will be ignored. #[allow(clippy::missing_panics_doc)] pub fn from_chunk_range(range: std::ops::Range) -> Self { Self::builder().range(range).build().expect("Cannot fail") } /// Set the number of concurrent chunks to process. 
#[must_use] pub fn with_concurrency(mut self, concurrency: usize) -> Self { self.concurrency = Some(concurrency); self } fn min_size(&self) -> usize { self.range.start } } impl ChunkMarkdownBuilder { fn default_client(&self) -> Arc> { let chunk_config: ChunkConfig = self .range .clone() .map(ChunkConfig::::from) .or_else(|| self.max_characters.map(Into::into)) .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into()); Arc::new(MarkdownSplitter::new(chunk_config)) } } #[async_trait] impl ChunkerTransformer for ChunkMarkdown { type Input = String; type Output = String; #[tracing::instrument(skip_all)] async fn transform_node(&self, node: TextNode) -> IndexingStream { let chunks = self .chunker .chunks(&node.chunk) .filter_map(|chunk| { let trim = chunk.trim(); if trim.is_empty() || trim.len() < self.min_size() { None } else { Some(chunk.to_string()) } }) .collect::>(); IndexingStream::iter( chunks .into_iter() .map(move |chunk| TextNode::build_from_other(&node).chunk(chunk).build()), ) } fn concurrency(&self) -> Option { self.concurrency } } #[cfg(test)] mod test { use super::*; use futures_util::stream::TryStreamExt; const MARKDOWN: &str = r" # Hello, world! This is a test markdown document. ## Section 1 This is a paragraph. ## Section 2 This is another paragraph. 
"; #[tokio::test] async fn test_transforming_with_max_characters_and_trimming() { let chunker = ChunkMarkdown::from_max_characters(40); let node = TextNode::new(MARKDOWN.to_string()); let nodes: Vec = chunker .transform_node(node) .await .try_collect() .await .unwrap(); dbg!(&nodes.iter().map(|n| n.chunk.clone()).collect::>()); for line in MARKDOWN.lines().filter(|line| !line.trim().is_empty()) { nodes .iter() .find(|node| node.chunk == line.trim()) .unwrap_or_else(|| panic!("Line not found: {line}")); } assert_eq!(nodes.len(), 6); } #[tokio::test] async fn test_always_within_range() { let ranges = vec![(10..15), (20..25), (30..35), (40..45), (50..55)]; for range in ranges { let chunker = ChunkMarkdown::from_chunk_range(range.clone()); let node = TextNode::new(MARKDOWN.to_string()); let nodes: Vec = chunker .transform_node(node) .await .try_collect() .await .unwrap(); // Assert all nodes chunk length within the range assert!( nodes.iter().all(|node| { let len = node.chunk.len(); range.contains(&len) }), "{:?}, {:?}", range, nodes.iter().filter(|node| { let len = node.chunk.len(); !range.contains(&len) }) ); } } #[test] fn test_builder() { ChunkMarkdown::builder() .chunker(MarkdownSplitter::new(40)) .concurrency(10) .range(10..20) .build() .unwrap(); } } ================================================ FILE: swiftide-indexing/src/transformers/chunk_text.rs ================================================ //! Chunk text content into smaller pieces use std::sync::Arc; use async_trait::async_trait; use derive_builder::Builder; use swiftide_core::{ChunkerTransformer, indexing::IndexingStream, indexing::TextNode}; use text_splitter::{Characters, ChunkConfig, TextSplitter}; const DEFAULT_MAX_CHAR_SIZE: usize = 2056; #[derive(Debug, Clone, Builder)] #[builder(setter(strip_option))] /// A transformer that chunks text content into smaller pieces. 
/// /// The transformer will split the text content into smaller pieces based on the specified /// `max_characters` or `range` of characters. /// /// For further customization, you can use the builder to create a custom splitter. Uses /// `text_splitter` under the hood. /// /// Technically that might work with every splitter `text_splitter` provides. pub struct ChunkText { /// The max number of concurrent chunks to process. /// /// Defaults to `None`. If you use a splitter that is resource heavy, this parameter can be /// tuned. #[builder(default)] concurrency: Option, /// Optional maximum number of characters per chunk. /// /// Defaults to [`DEFAULT_MAX_CHAR_SIZE`]. #[builder(default = "DEFAULT_MAX_CHAR_SIZE")] #[allow(dead_code)] max_characters: usize, /// A range of minimum and maximum characters per chunk. /// /// Chunks smaller than the range min will be ignored. `max_characters` will be ignored if this /// is set. /// /// If you provide a custom chunker with a range, you might want to set the range as well. /// /// Defaults to 0..[`max_characters`] #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")] range: std::ops::Range, /// The text splitter from [`text_splitter`] /// /// Defaults to a new [`TextSplitter`] with the specified `max_characters`. #[builder(setter(into), default = "self.default_client()")] chunker: Arc>, } impl Default for ChunkText { fn default() -> Self { Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE) } } impl ChunkText { pub fn builder() -> ChunkTextBuilder { ChunkTextBuilder::default() } /// Create a new transformer with a maximum number of characters per chunk. #[allow(clippy::missing_panics_doc)] pub fn from_max_characters(max_characters: usize) -> Self { Self::builder() .max_characters(max_characters) .build() .expect("Cannot fail") } /// Create a new transformer with a range of characters per chunk. /// /// Chunks smaller than the range will be ignored. 
#[allow(clippy::missing_panics_doc)] pub fn from_chunk_range(range: std::ops::Range) -> Self { Self::builder().range(range).build().expect("Cannot fail") } /// Set the number of concurrent chunks to process. #[must_use] pub fn with_concurrency(mut self, concurrency: usize) -> Self { self.concurrency = Some(concurrency); self } fn min_size(&self) -> usize { self.range.start } } impl ChunkTextBuilder { fn default_client(&self) -> Arc> { let chunk_config: ChunkConfig = self .range .clone() .map(ChunkConfig::::from) .or_else(|| self.max_characters.map(Into::into)) .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into()); Arc::new(TextSplitter::new(chunk_config)) } } #[async_trait] impl ChunkerTransformer for ChunkText { type Input = String; type Output = String; #[tracing::instrument(skip_all, name = "transformers.chunk_text")] async fn transform_node(&self, node: TextNode) -> IndexingStream { let chunks = self .chunker .chunks(&node.chunk) .filter_map(|chunk| { let trim = chunk.trim(); if trim.is_empty() || trim.len() < self.min_size() { None } else { Some(chunk.to_string()) } }) .collect::>(); IndexingStream::iter( chunks .into_iter() .map(move |chunk| TextNode::build_from_other(&node).chunk(chunk).build()), ) } fn concurrency(&self) -> Option { self.concurrency } } #[cfg(test)] mod test { use super::*; use futures_util::stream::TryStreamExt; const TEXT: &str = r" This is a text. This is a paragraph. This is another paragraph. 
"; #[tokio::test] async fn test_transforming_with_max_characters_and_trimming() { let chunker = ChunkText::from_max_characters(40); let node = TextNode::new(TEXT.to_string()); let nodes: Vec = chunker .transform_node(node) .await .try_collect() .await .unwrap(); for line in TEXT.lines().filter(|line| !line.trim().is_empty()) { assert!(nodes.iter().any(|node| node.chunk == line.trim())); } assert_eq!(nodes.len(), 3); } #[tokio::test] async fn test_always_within_range() { let ranges = vec![(10..15), (20..25), (30..35), (40..45), (50..55)]; for range in ranges { let chunker = ChunkText::from_chunk_range(range.clone()); let node = TextNode::new(TEXT.to_string()); let nodes: Vec = chunker .transform_node(node) .await .try_collect() .await .unwrap(); // Assert all nodes chunk length within the range assert!( nodes.iter().all(|node| { let len = node.chunk.len(); range.contains(&len) }), "{:?}, {:?}", range, nodes.iter().filter(|node| { let len = node.chunk.len(); !range.contains(&len) }) ); } } #[test] fn test_builder() { ChunkText::builder() .chunker(text_splitter::TextSplitter::new(40)) .concurrency(10) .range(10..20) .build() .unwrap(); } } ================================================ FILE: swiftide-indexing/src/transformers/embed.rs ================================================ //! Generic embedding transformer use std::{collections::VecDeque, sync::Arc}; use anyhow::bail; use async_trait::async_trait; use swiftide_core::{ BatchableTransformer, EmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults, indexing::{IndexingStream, TextNode}, }; /// A transformer that can generate embeddings for an `TextNode` /// /// This file defines the `Embed` struct and its implementation of the `BatchableTransformer` trait. 
#[derive(Clone)]
pub struct Embed {
    /// The embedding model used to produce vectors for each batch of nodes.
    model: Arc<dyn EmbeddingModel>,
    /// Per-transformer concurrency override; `None` defers to the pipeline default.
    concurrency: Option<usize>,
    /// Per-transformer batch-size override; `None` defers to the pipeline default.
    batch_size: Option<usize>,
}

impl std::fmt::Debug for Embed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `model` is a trait object without `Debug`; only the tuning knobs are shown.
        f.debug_struct("Embed")
            .field("concurrency", &self.concurrency)
            .field("batch_size", &self.batch_size)
            .finish()
    }
}

impl Embed {
    /// Creates a new instance of the `Embed` transformer.
    ///
    /// # Parameters
    ///
    /// * `model` - An embedding model that implements the `EmbeddingModel` trait.
    ///
    /// # Returns
    ///
    /// A new instance of `Embed`.
    pub fn new(model: impl EmbeddingModel + 'static) -> Self {
        Self {
            model: Arc::new(model),
            concurrency: None,
            batch_size: None,
        }
    }

    /// Set the number of concurrent batches to process.
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = Some(concurrency);
        self
    }

    /// Sets the batch size for the transformer.
    ///
    /// If the batch size is not set, the transformer will use the default batch size set by the
    /// pipeline.
    ///
    /// # Parameters
    ///
    /// * `batch_size` - The batch size to use for the transformer.
    ///
    /// # Returns
    ///
    /// A new instance of `Embed`.
    #[must_use]
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = Some(batch_size);
        self
    }
}

impl WithBatchIndexingDefaults for Embed {}
impl WithIndexingDefaults for Embed {}

#[async_trait]
impl BatchableTransformer for Embed {
    type Input = String;
    type Output = String;

    /// Transforms a batch of `TextNode` objects by generating embeddings for them.
    ///
    /// # Parameters
    ///
    /// * `nodes` - A vector of `TextNode` objects to be transformed.
    ///
    /// # Returns
    ///
    /// An `IndexingStream` containing the transformed `TextNode` objects with their embeddings.
    ///
    /// # Errors
    ///
    /// If the embedding process fails, the function returns a stream with the error.
    #[tracing::instrument(skip_all, name = "transformers.embed")]
    async fn batch_transform(&self, mut nodes: Vec<TextNode>) -> IndexingStream {
        // TODO: We should drop chunks that go over the token limit of the EmbedModel

        // EmbeddedFields grouped by node stored in order of processed nodes.
        let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
        // Embeddable data of every node, flattened in order of processed nodes, so a
        // single `embed` call can handle the whole batch.
        let embeddables_data = nodes
            .iter_mut()
            .fold(Vec::new(), |mut embeddables_data, node| {
                let embeddables = node.as_embeddables();
                let mut embeddables_keys = Vec::with_capacity(embeddables.len());
                for (embeddable_key, embeddable_data) in embeddables {
                    embeddables_keys.push(embeddable_key);
                    embeddables_data.push(embeddable_data);
                }
                embeddings_keys_groups.push_back(embeddables_keys);
                embeddables_data
            });

        // Embeddings vectors of every node stored in order of processed nodes.
        let mut embeddings = match self.model.embed(embeddables_data).await {
            Ok(embeddings) => VecDeque::from(embeddings),
            Err(err) => return IndexingStream::iter(vec![Err(err.into())]),
        };

        // Iterator of nodes with embeddings vectors map.
let nodes_iter = nodes.into_iter().map(move |mut node| { let Some(embedding_keys) = embeddings_keys_groups.pop_front() else { bail!("Missing embedding data"); }; node.vectors = embedding_keys .into_iter() .map(|embedded_field| { embeddings .pop_front() .map(|embedding| (embedded_field, embedding)) }) .collect(); Ok(node) }); IndexingStream::iter(nodes_iter) } fn concurrency(&self) -> Option { self.concurrency } fn batch_size(&self) -> Option { self.batch_size } } #[cfg(test)] mod tests { use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, TextNode}; use swiftide_core::{BatchableTransformer, MockEmbeddingModel}; use super::Embed; use futures_util::StreamExt; use mockall::predicate::*; use test_case::test_case; use swiftide_core::chat_completion::errors::LanguageModelError; #[derive(Clone)] struct TestData<'a> { pub embed_mode: EmbedMode, pub chunk: &'a str, pub metadata: Metadata, pub expected_embedables: Vec<&'a str>, pub expected_vectors: Vec<(EmbeddedField, Vec)>, } #[test_case(vec![ TestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "chunk_1", metadata: Metadata::from([("meta_1", "prompt_1")]), expected_embedables: vec!["meta_1: prompt_1\nchunk_1"], expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])] }, TestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "chunk_2", metadata: Metadata::from([("meta_2", "prompt_2")]), expected_embedables: vec!["meta_2: prompt_2\nchunk_2"], expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])] } ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")] #[test_case(vec![ TestData { embed_mode: EmbedMode::PerField, chunk: "chunk_1", metadata: Metadata::from([("meta_1", "prompt 1")]), expected_embedables: vec!["chunk_1", "prompt 1"], expected_vectors: vec![ (EmbeddedField::Chunk, vec![10f32]), (EmbeddedField::Metadata("meta_1".into()), vec![11f32]) ] }, TestData { embed_mode: EmbedMode::PerField, chunk: "chunk_2", metadata: Metadata::from([("meta_2", "prompt 2")]), 
expected_embedables: vec!["chunk_2", "prompt 2"], expected_vectors: vec![ (EmbeddedField::Chunk, vec![20f32]), (EmbeddedField::Metadata("meta_2".into()), vec![21f32]) ] } ]; "Multiple nodes EmbedMode::PerField with metadata.")] #[test_case(vec![ TestData { embed_mode: EmbedMode::Both, chunk: "chunk_1", metadata: Metadata::from([("meta_1", "prompt 1")]), expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"], expected_vectors: vec![ (EmbeddedField::Combined, vec![10f32]), (EmbeddedField::Chunk, vec![11f32]), (EmbeddedField::Metadata("meta_1".into()), vec![12f32]) ] }, TestData { embed_mode: EmbedMode::Both, chunk: "chunk_2", metadata: Metadata::from([("meta_2", "prompt 2")]), expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"], expected_vectors: vec![ (EmbeddedField::Combined, vec![20f32]), (EmbeddedField::Chunk, vec![21f32]), (EmbeddedField::Metadata("meta_2".into()), vec![22f32]) ] } ]; "Multiple nodes EmbedMode::Both with metadata.")] #[test_case(vec![ TestData { embed_mode: EmbedMode::Both, chunk: "chunk_1", metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]), expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"], expected_vectors: vec![ (EmbeddedField::Combined, vec![10f32]), (EmbeddedField::Chunk, vec![11f32]), (EmbeddedField::Metadata("meta_10".into()), vec![12f32]), (EmbeddedField::Metadata("meta_11".into()), vec![13f32]), (EmbeddedField::Metadata("meta_12".into()), vec![14f32]), ] }, TestData { embed_mode: EmbedMode::Both, chunk: "chunk_2", metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]), expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"], expected_vectors: vec![ (EmbeddedField::Combined, vec![20f32]), 
(EmbeddedField::Chunk, vec![21f32]), (EmbeddedField::Metadata("meta_20".into()), vec![22f32]), (EmbeddedField::Metadata("meta_21".into()), vec![23f32]), (EmbeddedField::Metadata("meta_22".into()), vec![24f32]) ] } ]; "Multiple nodes EmbedMode::Both with multiple metadata.")] #[test_case(vec![]; "No ingestion nodes")] #[tokio::test] async fn batch_transform(test_data: Vec>) { let test_nodes: Vec = test_data .iter() .map(|data| { TextNode::builder() .chunk(data.chunk) .metadata(data.metadata.clone()) .embed_mode(data.embed_mode) .build() .unwrap() }) .collect(); let expected_nodes: Vec = test_nodes .clone() .into_iter() .zip(test_data.iter()) .map(|(mut expected_node, test_data)| { expected_node.vectors = Some(test_data.expected_vectors.iter().cloned().collect()); expected_node }) .collect(); let expected_embeddables_batch = test_data .clone() .iter() .flat_map(|d| &d.expected_embedables) .map(ToString::to_string) .collect::>(); let expected_vectors_batch: Vec> = test_data .clone() .iter() .flat_map(|d| d.expected_vectors.iter().map(|(_, v)| v).cloned()) .collect(); let mut model_mock = MockEmbeddingModel::new(); model_mock .expect_embed() .withf(move |embeddables| expected_embeddables_batch.eq(embeddables)) .times(1) .returning_st(move |_| Ok(expected_vectors_batch.clone())); let embed = Embed::new(model_mock); let mut stream = embed.batch_transform(test_nodes).await; for expected_node in expected_nodes { let ingested_node = stream .next() .await .expect("IngestionStream has same length as expected_nodes") .expect("Is OK"); debug_assert_eq!(ingested_node, expected_node); } } #[tokio::test] async fn test_returns_error_properly_if_embed_fails() { let test_nodes = vec![TextNode::new("chunk")]; let mut model_mock = MockEmbeddingModel::new(); model_mock .expect_embed() .times(1) .returning(|_| Err(LanguageModelError::PermanentError("error".into()))); let embed = Embed::new(model_mock); let mut stream = embed.batch_transform(test_nodes).await; let error = stream .next() 
.await .expect("IngestionStream has same length as expected_nodes") .expect_err("Is Err"); assert_eq!(error.to_string(), "Permanent error: error"); } } ================================================ FILE: swiftide-indexing/src/transformers/metadata_keywords.rs ================================================ //! Extract keywords from a node and add them as metadata //! This module defines the `MetadataKeywords` struct and its associated methods, //! which are used for generating metadata in the form of keywords //! for a given text. It interacts with a client (e.g., `OpenAI`) to generate //! the keywords based on the text chunk in a `TextNode`. use anyhow::Result; use async_trait::async_trait; use swiftide_core::{Transformer, indexing::TextNode}; /// `MetadataKeywords` is responsible for generating keywords /// for a given text chunk. It uses a templated prompt to interact with a client /// that implements the `SimplePrompt` trait. #[swiftide_macros::indexing_transformer( default_prompt_file = "prompts/metadata_keywords.prompt.md", metadata_field_name = "Keywords" )] pub struct MetadataKeywords {} #[async_trait] impl Transformer for MetadataKeywords { type Input = String; type Output = String; /// Transforms an `TextNode` by extracting a keywords /// based on the text chunk within the node. /// /// # Arguments /// /// * `node` - The `TextNode` containing the text chunk to process. /// /// # Returns /// /// A `Result` containing the transformed `TextNode` with added metadata, /// or an error if the transformation fails. /// /// # Errors /// /// This function will return an error if the client fails to generate /// a keywords from the provided prompt. 
#[tracing::instrument(skip_all, name = "transformers.metadata_keywords")] async fn transform_node(&self, mut node: TextNode) -> Result { let prompt = self.prompt_template.clone().with_node(&node); let response = self.prompt(prompt).await?; node.metadata.insert(NAME, response); Ok(node) } fn concurrency(&self) -> Option { self.concurrency } } #[cfg(test)] mod test { use swiftide_core::MockSimplePrompt; use super::*; #[test_log::test(tokio::test)] async fn test_template() { let template = default_prompt(); let prompt = template.clone().with_node(&TextNode::new("test")); insta::assert_snapshot!(prompt.render().unwrap()); } #[tokio::test] async fn test_metadata_keywords() { let mut client = MockSimplePrompt::new(); client .expect_prompt() .returning(|_| Ok("important,keywords".to_string())); let transformer = MetadataKeywords::builder().client(client).build().unwrap(); let node = TextNode::new("Some text"); let result = transformer.transform_node(node).await.unwrap(); assert_eq!( result.metadata.get("Keywords").unwrap(), "important,keywords" ); } } ================================================ FILE: swiftide-indexing/src/transformers/metadata_qa_text.rs ================================================ //! Generates questions and answers from a given text chunk and adds them as metadata. //! This module defines the `MetadataQAText` struct and its associated methods, //! which are used for generating metadata in the form of questions and answers //! from a given text. It interacts with a client (e.g., `OpenAI`) to generate //! these questions and answers based on the text chunk in an `TextNode`. use anyhow::Result; use async_trait::async_trait; use swiftide_core::{Transformer, indexing::TextNode}; /// `MetadataQAText` is responsible for generating questions and answers /// from a given text chunk. It uses a templated prompt to interact with a client /// that implements the `SimplePrompt` trait. 
#[swiftide_macros::indexing_transformer( metadata_field_name = "Questions and Answers (text)", default_prompt_file = "prompts/metadata_qa_text.prompt.md" )] pub struct MetadataQAText { #[builder(default = "5")] num_questions: usize, } #[async_trait] impl Transformer for MetadataQAText { type Input = String; type Output = String; /// Transforms an `TextNode` by generating questions and answers /// based on the text chunk within the node. /// /// # Arguments /// /// * `node` - The `TextNode` containing the text chunk to process. /// /// # Returns /// /// A `Result` containing the transformed `TextNode` with added metadata, /// or an error if the transformation fails. /// /// # Errors /// /// This function will return an error if the client fails to generate /// questions and answers from the provided prompt. #[tracing::instrument(skip_all, name = "transformers.metadata_qa_text")] async fn transform_node(&self, mut node: TextNode) -> Result { let prompt = self .prompt_template .clone() .with_node(&node) .with_context_value("questions", self.num_questions); let response = self.prompt(prompt).await?; node.metadata.insert(NAME, response); Ok(node) } fn concurrency(&self) -> Option { self.concurrency } } #[cfg(test)] mod test { use swiftide_core::MockSimplePrompt; use super::*; #[tokio::test] async fn test_template() { let template = default_prompt(); let prompt = template .clone() .with_node(&TextNode::new("test")) .with_context_value("questions", 5); insta::assert_snapshot!(prompt.render().unwrap()); } #[tokio::test] async fn test_metadata_qacode() { let mut client = MockSimplePrompt::new(); client .expect_prompt() .returning(|_| Ok("Q1: Hello\nA1: World".to_string())); let transformer = MetadataQAText::builder().client(client).build().unwrap(); let node = TextNode::new("Some text"); let result = transformer.transform_node(node).await.unwrap(); assert_eq!( result.metadata.get("Questions and Answers (text)").unwrap(), "Q1: Hello\nA1: World" ); } } 
================================================ FILE: swiftide-indexing/src/transformers/metadata_summary.rs ================================================ //! Generate a summary and adds it as metadata //! This module defines the `MetadataSummary` struct and its associated methods, //! which are used for generating metadata in the form of a summary //! for a given text. It interacts with a client (e.g., `OpenAI`) to generate //! the summary based on the text chunk in an `TextNode`. use anyhow::Result; use async_trait::async_trait; use swiftide_core::{Transformer, indexing::TextNode}; /// `MetadataSummary` is responsible for generating a summary /// for a given text chunk. It uses a templated prompt to interact with a client /// that implements the `SimplePrompt` trait. #[swiftide_macros::indexing_transformer( metadata_field_name = "Summary", default_prompt_file = "prompts/metadata_summary.prompt.md" )] pub struct MetadataSummary {} #[async_trait] impl Transformer for MetadataSummary { type Input = String; type Output = String; /// Transforms an `TextNode` by extracting a summary /// based on the text chunk within the node. /// /// # Arguments /// /// * `node` - The `TextNode` containing the text chunk to process. /// /// # Returns /// /// A `Result` containing the transformed `TextNode` with added metadata, /// or an error if the transformation fails. /// /// # Errors /// /// This function will return an error if the client fails to generate /// a summary from the provided prompt. 
#[tracing::instrument(skip_all, name = "transformers.metadata_summary")] async fn transform_node(&self, mut node: TextNode) -> Result { let prompt = self.prompt_template.clone().with_node(&node); let response = self.prompt(prompt).await?; node.metadata.insert(NAME, response); Ok(node) } fn concurrency(&self) -> Option { self.concurrency } } #[cfg(test)] mod test { use swiftide_core::MockSimplePrompt; use super::*; #[tokio::test] async fn test_template() { let template = default_prompt(); let prompt = template.clone().with_node(&TextNode::new("test")); insta::assert_snapshot!(prompt.render().unwrap()); } #[tokio::test] async fn test_metadata_summary() { let mut client = MockSimplePrompt::new(); client .expect_prompt() .returning(|_| Ok("A Summary".to_string())); let transformer = MetadataSummary::builder().client(client).build().unwrap(); let node = TextNode::new("Some text"); let result = transformer.transform_node(node).await.unwrap(); assert_eq!(result.metadata.get("Summary").unwrap(), "A Summary"); } } ================================================ FILE: swiftide-indexing/src/transformers/metadata_title.rs ================================================ //! Generate a title and adds it as metadata //! This module defines the `MetadataTitle` struct and its associated methods, //! which are used for generating metadata in the form of a title //! for a given text. It interacts with a client (e.g., `OpenAI`) to generate //! these questions and answers based on the text chunk in an `TextNode`. use anyhow::Result; use async_trait::async_trait; use swiftide_core::{Transformer, indexing::TextNode}; /// `MetadataTitle` is responsible for generating a title /// for a given text chunk. It uses a templated prompt to interact with a client /// that implements the `SimplePrompt` trait. 
#[swiftide_macros::indexing_transformer( metadata_field_name = "Title", default_prompt_file = "prompts/metadata_title.prompt.md" )] pub struct MetadataTitle {} #[async_trait] impl Transformer for MetadataTitle { type Input = String; type Output = String; /// Transforms an `TextNode` by generating questions and answers /// based on the text chunk within the node. /// /// # Arguments /// /// * `node` - The `TextNode` containing the text chunk to process. /// /// # Returns /// /// A `Result` containing the transformed `TextNode` with added metadata, /// or an error if the transformation fails. /// /// # Errors /// /// This function will return an error if the client fails to generate /// questions and answers from the provided prompt. #[tracing::instrument(skip_all, name = "transformers.metadata_title")] async fn transform_node(&self, mut node: TextNode) -> Result { let prompt = self.prompt_template.clone().with_node(&node); let response = self.prompt(prompt).await?; node.metadata.insert(NAME, response); Ok(node) } fn concurrency(&self) -> Option { self.concurrency } } #[cfg(test)] mod test { use swiftide_core::MockSimplePrompt; use super::*; #[test_log::test(tokio::test)] async fn test_template() { let template = default_prompt(); let prompt = template.clone().with_node(&TextNode::new("test")); insta::assert_snapshot!(prompt.render().unwrap()); } #[tokio::test] async fn test_metadata_title() { let mut client = MockSimplePrompt::new(); client .expect_prompt() .returning(|_| Ok("A Title".to_string())); let transformer = MetadataTitle::builder().client(client).build().unwrap(); let node = TextNode::new("Some text"); let result = transformer.transform_node(node).await.unwrap(); assert_eq!(result.metadata.get("Title").unwrap(), "A Title"); } } ================================================ FILE: swiftide-indexing/src/transformers/mod.rs ================================================ //! Various transformers for chunking, embedding and transforming data //! //! 
These transformers are generic over their implementation and many require a //! swiftide integration to be configured. //! //! Transformers that prompt have a default prompt configured. Prompts can be customized //! and tailored, supporting Jinja style templating based on [tera](https://docs.rs/tera/latest/tera/). //! //! See [`swiftide_core::prompt::Prompt`] and [`swiftide_core::template::Template`] pub mod chunk_markdown; pub mod chunk_text; pub mod embed; pub mod metadata_keywords; pub mod metadata_qa_text; pub mod metadata_summary; pub mod metadata_title; pub mod sparse_embed; pub use chunk_markdown::ChunkMarkdown; pub use chunk_text::ChunkText; pub use embed::Embed; pub use metadata_keywords::MetadataKeywords; pub use metadata_qa_text::MetadataQAText; pub use metadata_summary::MetadataSummary; pub use metadata_title::MetadataTitle; pub use sparse_embed::SparseEmbed; ================================================ FILE: swiftide-indexing/src/transformers/prompts/metadata_keywords.prompt.md ================================================ # Task Your task is to generate a descriptive, concise keywords for the given text # Constraints - Only respond in the example format - Respond with a keywords that are representative of the text - Only include keywords that are literally included in the text - Respond with a comma-separated list of keywords # Example Respond in the following example format and do not include anything else: ``` , ``` # Text ``` {{ node.chunk }} ``` ================================================ FILE: swiftide-indexing/src/transformers/prompts/metadata_qa_text.prompt.md ================================================ # Task Your task is to generate questions and answers for the given text. Given that somebody else might ask questions about the text, consider things like: - What does this text do? - What other internal parts does the text use? - Does this text have any dependencies? - What are some potential use cases for this text? - ... 
and so on # Constraints - Generate at most {{questions}} questions and answers. - Only respond in the example format - Only respond with questions and answers that can be derived from the text. # Example Respond in the following example format and do not include anything else: ``` Q1: What is the capital of France? A1: Paris. ``` # text ``` {{node.chunk}} ``` ================================================ FILE: swiftide-indexing/src/transformers/prompts/metadata_summary.prompt.md ================================================ # Task Your task is to generate a descriptive, concise summary for the given text # Constraints - Only respond in the example format - Respond with a summary that is accurate and descriptive without fluff - Only include information that is included in the text # Example Respond in the following example format and do not include anything else: ```
<summary>
``` # Text ``` {{node.chunk}} ``` ================================================ FILE: swiftide-indexing/src/transformers/prompts/metadata_title.prompt.md ================================================ # Task Your task is to generate a descriptive, concise title for the given text # Constraints - Only respond in the example format - Respond with a title that is accurate and descriptive without fluff # Example Respond in the following example format and do not include anything else: ``` <title> ``` # Text ``` {{ node.chunk }} ``` ================================================ FILE: swiftide-indexing/src/transformers/snapshots/swiftide_indexing__transformers__compress_code_outline__test__compress_code_template.snap ================================================ --- source: swiftide-indexing/src/transformers/compress_code_outline.rs expression: prompt.render().await.unwrap() --- # Filtering Code Outline Your task is to filter the given file outline to the code chunk provided. The goal is to provide a context that is still contains the lines needed for understanding the code in the chunk whilst leaving out any irrelevant information. ## Constraints * Only use lines from the provided context, do not add any additional information * Ensure that the selection you make is the most appropriate for the code chunk * Make sure you include any definitions or imports that are used in the code chunk * You do not need to repeat the code chunk in your response, it will be appended directly after your response.
* Do not use lines that are present in the code chunk ## Code ``` Code using outline ``` ## Outline ``` Relevant Outline ``` ================================================ FILE: swiftide-indexing/src/transformers/snapshots/swiftide_indexing__transformers__metadata_keywords__test__template.snap ================================================ --- source: swiftide-indexing/src/transformers/metadata_keywords.rs expression: prompt.render().await.unwrap() --- # Task Your task is to generate a descriptive, concise keywords for the given text # Constraints - Only respond in the example format - Respond with a keywords that are representative of the text - Only include keywords that are literally included in the text - Respond with a comma-separated list of keywords # Example Respond in the following example format and do not include anything else: ``` <keyword>,<other-keyword> ``` # Text ``` test ``` ================================================ FILE: swiftide-indexing/src/transformers/snapshots/swiftide_indexing__transformers__metadata_qa_code__test__template.snap ================================================ --- source: swiftide-indexing/src/transformers/metadata_qa_code.rs expression: prompt.render().await.unwrap() --- # Task Your task is to generate questions and answers for the given code. Given that somebody else might ask questions about the code, consider things like: - What does this code do? - What other internal parts does the code use? - Does this code have any dependencies? - What are some potential use cases for this code? - ... and so on # Constraints - Generate only 5 questions and answers. - Only respond in the example format - Only respond with questions and answers that can be derived from the code. # Example Respond in the following example format and do not include anything else: ``` Q1: What does this code do? A1: It transforms strings into integers. Q2: What other internal parts does the code use? A2: A hasher to hash the strings. 
``` # Code ``` test ``` ================================================ FILE: swiftide-indexing/src/transformers/snapshots/swiftide_indexing__transformers__metadata_qa_code__test__template_with_outline.snap ================================================ --- source: swiftide-indexing/src/transformers/metadata_qa_code.rs expression: prompt.render().await.unwrap() --- # Task Your task is to generate questions and answers for the given code. Given that somebody else might ask questions about the code, consider things like: - What does this code do? - What other internal parts does the code use? - Does this code have any dependencies? - What are some potential use cases for this code? - ... and so on # Constraints - Generate only 5 questions and answers. - Only respond in the example format - Only respond with questions and answers that can be derived from the code. # Example Respond in the following example format and do not include anything else: ``` Q1: What does this code do? A1: It transforms strings into integers. Q2: What other internal parts does the code use? A2: A hasher to hash the strings. ``` ## Outline of the parent file ``` Test outline ``` # Code ``` test ``` ================================================ FILE: swiftide-indexing/src/transformers/snapshots/swiftide_indexing__transformers__metadata_qa_text__test__template.snap ================================================ --- source: swiftide-indexing/src/transformers/metadata_qa_text.rs expression: prompt.render().await.unwrap() --- # Task Your task is to generate questions and answers for the given text. Given that somebody else might ask questions about the text, consider things like: - What does this text do? - What other internal parts does the text use? - Does this text have any dependencies? - What are some potential use cases for this text? - ... and so on # Constraints - Generate at most 5 questions and answers. 
- Only respond in the example format - Only respond with questions and answers that can be derived from the text. # Example Respond in the following example format and do not include anything else: ``` Q1: What is the capital of France? A1: Paris. ``` # text ``` test ``` ================================================ FILE: swiftide-indexing/src/transformers/snapshots/swiftide_indexing__transformers__metadata_summary__test__template.snap ================================================ --- source: swiftide-indexing/src/transformers/metadata_summary.rs expression: prompt.render().await.unwrap() --- # Task Your task is to generate a descriptive, concise summary for the given text # Constraints - Only respond in the example format - Respond with a summary that is accurate and descriptive without fluff - Only include information that is included in the text # Example Respond in the following example format and do not include anything else: ``` <summary> ``` # Text ``` test ``` ================================================ FILE: swiftide-indexing/src/transformers/snapshots/swiftide_indexing__transformers__metadata_title__test__template.snap ================================================ --- source: swiftide-indexing/src/transformers/metadata_title.rs expression: prompt.render().await.unwrap() --- # Task Your task is to generate a descriptive, concise title for the given text # Constraints - Only respond in the example format - Respond with a title that is accurate and descriptive without fluff # Example Respond in the following example format and do not include anything else: ``` <title> ``` # Text ``` test ``` ================================================ FILE: swiftide-indexing/src/transformers/sparse_embed.rs ================================================ //! 
Generic embedding transformer use std::{collections::VecDeque, sync::Arc}; use anyhow::bail; use async_trait::async_trait; use swiftide_core::{ BatchableTransformer, SparseEmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults, indexing::{IndexingStream, TextNode}, }; /// A transformer that can generate embeddings for an `TextNode` /// /// This file defines the `SparseEmbed` struct and its implementation of the `BatchableTransformer` /// trait. #[derive(Clone)] pub struct SparseEmbed { embed_model: Arc<dyn SparseEmbeddingModel>, concurrency: Option<usize>, batch_size: Option<usize>, } impl std::fmt::Debug for SparseEmbed { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SparseEmbed") .field("concurrency", &self.concurrency) .finish() } } impl SparseEmbed { /// Creates a new instance of the `SparseEmbed` transformer. /// /// # Parameters /// /// * `model` - An embedding model that implements the `SparseEmbeddingModel` trait. /// /// # Returns /// /// A new instance of `SparseEmbed`. pub fn new(model: impl SparseEmbeddingModel + 'static) -> Self { Self { embed_model: Arc::new(model), concurrency: None, batch_size: None, } } #[must_use] pub fn with_concurrency(mut self, concurrency: usize) -> Self { self.concurrency = Some(concurrency); self } /// Sets the batch size for the transformer. /// If the batch size is not set, the transformer will use the default batch size set by the /// pipeline # Parameters /// /// * `batch_size` - The batch size to use for the transformer. /// /// # Returns /// /// A new instance of `Embed`. #[must_use] pub fn with_batch_size(mut self, batch_size: usize) -> Self { self.batch_size = Some(batch_size); self } } impl WithBatchIndexingDefaults for SparseEmbed {} impl WithIndexingDefaults for SparseEmbed {} #[async_trait] impl BatchableTransformer for SparseEmbed { type Input = String; type Output = String; /// Transforms a batch of `TextNode` objects by generating embeddings for them. 
/// /// # Parameters /// /// * `nodes` - A vector of `TextNode` objects to be transformed. /// /// # Returns /// /// An `IndexingStream` containing the transformed `TextNode` objects with their embeddings. /// /// # Errors /// /// If the embedding process fails, the function returns a stream with the error. #[tracing::instrument(skip_all, name = "transformers.embed")] async fn batch_transform(&self, mut nodes: Vec<TextNode>) -> IndexingStream<String> { // TODO: We should drop chunks that go over the token limit of the SparseEmbedModel // EmbeddedFields grouped by node stored in order of processed nodes. let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len()); // SparseEmbeddable data of every node stored in order of processed nodes. let embeddables_data = nodes .iter_mut() .fold(Vec::new(), |mut embeddables_data, node| { let embeddables = node.as_embeddables(); let mut embeddables_keys = Vec::with_capacity(embeddables.len()); for (embeddable_key, embeddable_data) in embeddables { embeddables_keys.push(embeddable_key); embeddables_data.push(embeddable_data); } embeddings_keys_groups.push_back(embeddables_keys); embeddables_data }); // SparseEmbeddings vectors of every node stored in order of processed nodes. let mut embeddings = match self.embed_model.sparse_embed(embeddables_data).await { Ok(embeddngs) => VecDeque::from(embeddngs), Err(err) => return IndexingStream::iter(vec![Err(err.into())]), }; // Iterator of nodes with embeddings vectors map. 
let nodes_iter = nodes.into_iter().map(move |mut node| { let Some(embedding_keys) = embeddings_keys_groups.pop_front() else { bail!("Missing embedding data"); }; node.sparse_vectors = embedding_keys .into_iter() .map(|embedded_field| { embeddings .pop_front() .map(|embedding| (embedded_field, embedding)) }) .collect(); Ok(node) }); IndexingStream::iter(nodes_iter) } fn concurrency(&self) -> Option<usize> { self.concurrency } fn batch_size(&self) -> Option<usize> { self.batch_size } } #[cfg(test)] mod tests { use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, TextNode}; use swiftide_core::{ BatchableTransformer, MockSparseEmbeddingModel, SparseEmbedding, SparseEmbeddings, }; use super::SparseEmbed; use futures_util::StreamExt; use mockall::predicate::*; use test_case::test_case; use swiftide_core::chat_completion::errors::LanguageModelError; #[derive(Clone)] struct TestData<'a> { pub embed_mode: EmbedMode, pub chunk: &'a str, pub metadata: Metadata, pub expected_embedables: Vec<&'a str>, pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>, } #[test_case(vec![ TestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "chunk_1", metadata: Metadata::from([("meta_1", "prompt_1")]), expected_embedables: vec!["meta_1: prompt_1\nchunk_1"], expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])] }, TestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "chunk_2", metadata: Metadata::from([("meta_2", "prompt_2")]), expected_embedables: vec!["meta_2: prompt_2\nchunk_2"], expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])] } ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")] #[test_case(vec![ TestData { embed_mode: EmbedMode::PerField, chunk: "chunk_1", metadata: Metadata::from([("meta_1", "prompt 1")]), expected_embedables: vec!["chunk_1", "prompt 1"], expected_vectors: vec![ (EmbeddedField::Chunk, vec![10f32]), (EmbeddedField::Metadata("meta_1".into()), vec![11f32]) ] }, TestData { embed_mode: EmbedMode::PerField, 
chunk: "chunk_2", metadata: Metadata::from([("meta_2", "prompt 2")]), expected_embedables: vec!["chunk_2", "prompt 2"], expected_vectors: vec![ (EmbeddedField::Chunk, vec![20f32]), (EmbeddedField::Metadata("meta_2".into()), vec![21f32]) ] } ]; "Multiple nodes EmbedMode::PerField with metadata.")] #[test_case(vec![ TestData { embed_mode: EmbedMode::Both, chunk: "chunk_1", metadata: Metadata::from([("meta_1", "prompt 1")]), expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"], expected_vectors: vec![ (EmbeddedField::Combined, vec![10f32]), (EmbeddedField::Chunk, vec![11f32]), (EmbeddedField::Metadata("meta_1".into()), vec![12f32]) ] }, TestData { embed_mode: EmbedMode::Both, chunk: "chunk_2", metadata: Metadata::from([("meta_2", "prompt 2")]), expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"], expected_vectors: vec![ (EmbeddedField::Combined, vec![20f32]), (EmbeddedField::Chunk, vec![21f32]), (EmbeddedField::Metadata("meta_2".into()), vec![22f32]) ] } ]; "Multiple nodes EmbedMode::Both with metadata.")] #[test_case(vec![ TestData { embed_mode: EmbedMode::Both, chunk: "chunk_1", metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]), expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"], expected_vectors: vec![ (EmbeddedField::Combined, vec![10f32]), (EmbeddedField::Chunk, vec![11f32]), (EmbeddedField::Metadata("meta_10".into()), vec![12f32]), (EmbeddedField::Metadata("meta_11".into()), vec![13f32]), (EmbeddedField::Metadata("meta_12".into()), vec![14f32]), ] }, TestData { embed_mode: EmbedMode::Both, chunk: "chunk_2", metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]), expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"], 
expected_vectors: vec![ (EmbeddedField::Combined, vec![20f32]), (EmbeddedField::Chunk, vec![21f32]), (EmbeddedField::Metadata("meta_20".into()), vec![22f32]), (EmbeddedField::Metadata("meta_21".into()), vec![23f32]), (EmbeddedField::Metadata("meta_22".into()), vec![24f32]) ] } ]; "Multiple nodes EmbedMode::Both with multiple metadata.")] #[test_case(vec![]; "No ingestion nodes")] #[tokio::test] async fn batch_transform(test_data: Vec<TestData<'_>>) { let test_nodes: Vec<TextNode> = test_data .iter() .map(|data| { TextNode::builder() .chunk(data.chunk) .metadata(data.metadata.clone()) .embed_mode(data.embed_mode) .build() .unwrap() }) .collect(); let expected_nodes: Vec<TextNode> = test_nodes .clone() .into_iter() .zip(test_data.iter()) .map(|(mut expected_node, test_data)| { expected_node.sparse_vectors = Some( test_data .expected_vectors .iter() .cloned() .map(|d| { ( d.0, SparseEmbedding { indices: vec![0], values: d.1, }, ) }) .collect(), ); expected_node }) .collect(); let expected_embeddables_batch = test_data .clone() .iter() .flat_map(|d| &d.expected_embedables) .map(ToString::to_string) .collect::<Vec<String>>(); let expected_vectors_batch: SparseEmbeddings = test_data .clone() .iter() .flat_map(|d| { d.expected_vectors .iter() .map(|(_, v)| v) .cloned() .map(|v| SparseEmbedding { indices: vec![0], values: v, }) }) .collect(); let mut model_mock = MockSparseEmbeddingModel::new(); model_mock .expect_sparse_embed() .withf(move |embeddables| expected_embeddables_batch.eq(embeddables)) .times(1) .returning_st(move |_| Ok(expected_vectors_batch.clone())); let embed = SparseEmbed::new(model_mock); let mut stream = embed.batch_transform(test_nodes).await; for expected_node in expected_nodes { let ingested_node = stream .next() .await .expect("IngestionStream has same length as expected_nodes") .expect("Is OK"); debug_assert_eq!(ingested_node, expected_node); } } #[tokio::test] async fn test_returns_error_properly_if_sparse_embed_fails() { let test_nodes = 
vec![TextNode::new("chunk")]; let mut model_mock = MockSparseEmbeddingModel::new(); model_mock .expect_sparse_embed() .times(1) .returning(|_| Err(LanguageModelError::PermanentError("error".into()))); let embed = SparseEmbed::new(model_mock); let mut stream = embed.batch_transform(test_nodes).await; let error = stream .next() .await .expect("IngestionStream has same length as expected_nodes") .expect_err("Is Err"); assert_eq!(error.to_string(), "Permanent error: error"); } } ================================================ FILE: swiftide-integrations/Cargo.toml ================================================ cargo-features = ["edition2024"] [package] name = "swiftide-integrations" version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true [dependencies] swiftide-core = { path = "../swiftide-core", version = "0.32" } swiftide-macros = { path = "../swiftide-macros", version = "0.32" } anyhow = { workspace = true } async-trait = { workspace = true } derive_builder = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } base64 = { workspace = true } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } itertools = { workspace = true } chrono = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } thiserror = { workspace = true } regex = { workspace = true } futures-util = { workspace = true } tera = { workspace = true } uuid = { workspace = true } metrics = { workspace = true, optional = true } tracing-futures = { version = "0.2.5", features = ["futures-03"] } schemars.workspace = true # Integrations async-openai = { workspace = true, optional = true, features = [ "rustls", "chat-completion", "embedding", "responses", ] } async-anthropic = { workspace = true, optional = true } qdrant-client = { workspace = 
true, optional = true, default-features = false, features = [ "serde", ] } sqlx = { workspace = true, optional = true, features = [ "any", "json", "macros", "postgres", "runtime-tokio", "chrono", "uuid", ] } pgvector = { workspace = true, optional = true, features = ["sqlx"] } redis = { workspace = true, features = [ "aio", "tokio-comp", "connection-manager", "tokio-rustls-comp", ], optional = true } tree-sitter = { workspace = true, optional = true } tree-sitter-rust = { workspace = true, optional = true } tree-sitter-python = { workspace = true, optional = true } tree-sitter-ruby = { workspace = true, optional = true } tree-sitter-typescript = { workspace = true, optional = true } tree-sitter-javascript = { workspace = true, optional = true } tree-sitter-java = { workspace = true, optional = true } tree-sitter-go = { workspace = true, optional = true } tree-sitter-solidity = { workspace = true, optional = true } tree-sitter-c = { workspace = true, optional = true } tree-sitter-cpp = { workspace = true, optional = true } tree-sitter-c-sharp = { workspace = true, optional = true } tree-sitter-elixir = { workspace = true, optional = true } tree-sitter-html = { workspace = true, optional = true } tree-sitter-php = { workspace = true, optional = true } fastembed = { workspace = true, optional = true } spider = { workspace = true, optional = true, default-features = true } htmd = { workspace = true, optional = true } aws-config = { workspace = true, features = [ "behavior-version-latest", "credentials-login", "default-https-client", "rt-tokio", ], optional = true } aws-credential-types = { workspace = true, features = [ "hardcoded-credentials", ], optional = true } aws-sdk-bedrockruntime = { workspace = true, features = [ "behavior-version-latest", "default-https-client", "rt-tokio", ], optional = true } aws-smithy-types = { workspace = true, optional = true } aws-smithy-json = { version = "0.62.4", optional = true } secrecy = { workspace = true, optional = true } 
reqwest = { workspace = true, optional = true } reqwest-eventsource = { workspace = true, optional = true } deadpool = { workspace = true, features = [ "managed", "rt_tokio_1", ], optional = true } fluvio = { workspace = true, optional = true } rdkafka = { workspace = true, optional = true } arrow-array = { version = "57.3", default-features = false, optional = true } lancedb = { workspace = true, optional = true } parquet = { workspace = true, optional = true, features = [ "async", "arrow", "snap", ] } redb = { workspace = true, optional = true } duckdb = { workspace = true, optional = true } libduckdb-sys = { workspace = true, optional = true } fs-err = { workspace = true, features = ["tokio"] } tiktoken-rs = { workspace = true, optional = true } [dev-dependencies] swiftide-core = { path = "../swiftide-core", features = ["test-utils"] } swiftide-test-utils = { path = "../swiftide-test-utils", features = [ "test-utils", ] } swiftide-macros = { path = "../swiftide-macros" } temp-dir = { workspace = true } pretty_assertions = { workspace = true } # arrow = { workspace = true, features = ["test_utils"] } duckdb = { workspace = true, features = ["bundled"] } libduckdb-sys = { workspace = true, features = [ "bundled", "vcpkg", "pkg-config", ] } # Used for hacking fluv to play nice flv-util = { workspace = true } mockall = { workspace = true } test-log = { workspace = true } testcontainers = { workspace = true } testcontainers-modules = { workspace = true, features = ["kafka"] } test-case = { workspace = true } indoc = { workspace = true } insta = { workspace = true } wiremock = { workspace = true } tokio-stream = { workspace = true } eventsource-stream.workspace = true aws-smithy-eventstream = "0.60.19" tracing-subscriber = { workspace = true } [features] default = ["rustls"] metrics = ["dep:metrics", "swiftide-core/metrics"] # Ensures rustls is used rustls = ["reqwest?/rustls", "fastembed?/hf-hub-native-tls"] # Qdrant for storage qdrant = ["dep:qdrant-client", 
"swiftide-core/qdrant", "chrono/now"] # PgVector for storage pgvector = ["dep:sqlx", "dep:pgvector"] # Redis for caching and storage redis = ["dep:redis"] # Tree-sitter for code operations and chunking tree-sitter = [ "dep:tree-sitter", "dep:tree-sitter-rust", "dep:tree-sitter-python", "dep:tree-sitter-ruby", "dep:tree-sitter-typescript", "dep:tree-sitter-javascript", "dep:tree-sitter-java", "dep:tree-sitter-go", "dep:tree-sitter-solidity", "dep:tree-sitter-c", "dep:tree-sitter-cpp", "dep:tree-sitter-c-sharp", "dep:tree-sitter-elixir", "dep:tree-sitter-html", "dep:tree-sitter-php", ] # OpenAI for embedding and prompting openai = [ "dep:async-openai", "tiktoken-rs?/async-openai", "dep:reqwest-eventsource", "dep:reqwest", "swiftide-core/openai", ] # Groq groq = ["dep:async-openai", "dep:secrecy", "dep:reqwest", "openai"] # Google Gemini gemini = ["dep:async-openai", "dep:secrecy", "dep:reqwest", "openai"] # Ollama prompting, embedding, chatcompletion ollama = ["dep:async-openai", "dep:secrecy", "dep:reqwest", "openai"] # Openrouter prompting, embedding, chatcompletion open-router = ["dep:async-openai", "dep:secrecy", "dep:reqwest", "openai"] # FastEmbed (by qdrant) for fast, local embeddings fastembed = [ "dep:fastembed", "fastembed/ort-download-binaries", "fastembed/hf-hub", ] # Dashscope prompting dashscope = ["dep:async-openai", "dep:secrecy", "dep:reqwest", "openai"] # Scraping via spider as loader and a html to markdown transformer scraping = ["dep:spider", "dep:htmd"] # AWS Bedrock for prompting aws-bedrock = [ "dep:aws-config", "dep:aws-credential-types", "dep:aws-sdk-bedrockruntime", "dep:aws-smithy-types", "dep:aws-smithy-json", ] lancedb = ["dep:lancedb", "dep:deadpool", "dep:arrow-array"] # Fluvio loader fluvio = ["dep:fluvio"] # Kafka loader kafka = ["dep:rdkafka"] # Parquet loader parquet = ["dep:arrow-array", "dep:parquet"] # Anthropic for prompting and completions anthropic = ["dep:async-anthropic"] # Duckdb for indexing and retrieval duckdb =
["dep:duckdb", "dep:libduckdb-sys"] tiktoken = ["dep:tiktoken-rs"] # Langfuse compatibility langfuse = [] [lints] workspace = true [package.metadata.docs.rs] all-features = true cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: swiftide-integrations/src/anthropic/chat_completion.rs ================================================ use futures_util::{StreamExt as _, TryStreamExt as _, stream}; use std::sync::{Arc, Mutex}; use anyhow::{Context as _, Result}; use async_anthropic::types::{ CreateMessagesRequestBuilder, Message, MessageBuilder, MessageContent, MessageContentList, MessageRole, MessagesStreamEvent, ToolChoice, ToolResultBuilder, ToolUseBuilder, }; use async_trait::async_trait; use serde_json::{Value, json}; use swiftide_core::{ ChatCompletion, ChatCompletionStream, chat_completion::{ ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolOutput, ToolSpec, Usage, UsageBuilder, errors::LanguageModelError, }, }; use super::Anthropic; use super::tool_schema::AnthropicToolSchema; #[cfg(feature = "metrics")] use swiftide_core::metrics::emit_usage; #[async_trait] impl ChatCompletion for Anthropic { #[tracing::instrument(skip_all, err)] async fn complete( &self, request: &ChatCompletionRequest<'_>, ) -> Result<ChatCompletionResponse, LanguageModelError> { let model = &self.default_options.prompt_model; let request = self .build_request(request) .and_then(|b| b.build().map_err(LanguageModelError::permanent))?; tracing::debug!( model = &model, messages = serde_json::to_string_pretty(&request).expect("Infallible"), "[ChatCompletion] Request to anthropic" ); let response = self .client .messages() .create(request) .await .map_err(LanguageModelError::permanent)?; tracing::debug!( response = serde_json::to_string_pretty(&response).expect("Infallible"), "[ChatCompletion] Response from anthropic" ); let maybe_tool_calls = response .messages() .iter() 
.flat_map(Message::tool_uses) .map(|atool| { ToolCall::builder() .id(atool.id) .name(atool.name) .args(atool.input.to_string()) .build() .expect("infallible") }) .collect::<Vec<_>>(); let maybe_tool_calls = if maybe_tool_calls.is_empty() { None } else { Some(maybe_tool_calls) }; let mut builder = ChatCompletionResponse::builder() .maybe_message(response.messages().iter().find_map(Message::text)) .maybe_tool_calls(maybe_tool_calls) .to_owned(); if let Some(usage) = &response.usage { let input_tokens = usage.input_tokens.unwrap_or_default(); let output_tokens = usage.output_tokens.unwrap_or_default(); let total_tokens = input_tokens + output_tokens; #[cfg(feature = "metrics")] emit_usage( model, input_tokens.into(), output_tokens.into(), total_tokens.into(), self.metric_metadata.as_ref(), ); let usage = Usage { prompt_tokens: input_tokens, completion_tokens: output_tokens, total_tokens, details: None, }; if let Some(callback) = &self.on_usage { callback(&usage).await?; } let usage = UsageBuilder::default() .prompt_tokens(input_tokens) .completion_tokens(output_tokens) .total_tokens(total_tokens) .build() .map_err(LanguageModelError::permanent)?; builder.usage(usage); } builder.build().map_err(LanguageModelError::from) } #[tracing::instrument(skip_all)] async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream { let model = &self.default_options.prompt_model; let request = match self .build_request(request) .and_then(|b| b.build().map_err(LanguageModelError::permanent)) { Ok(request) => request, Err(e) => { return e.into(); } }; tracing::debug!( model = &model, messages = serde_json::to_string_pretty(&request).expect("Infallible"), "[ChatCompletion] Request to anthropic" ); let response = self.client.messages().create_stream(request).await; let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default())); let final_response = Arc::clone(&accumulating_response); #[cfg(feature = "metrics")] let model = model.clone(); 
#[cfg(feature = "metrics")] let metric_metadata = self.metric_metadata.clone(); let maybe_usage_callback = self.on_usage.clone(); response .map_ok(move |chunk| { let accumulating_response = Arc::clone(&accumulating_response); let mut lock = accumulating_response.lock().unwrap(); append_delta_from_chunk(&chunk, &mut lock); lock.clone() }) .map_err(LanguageModelError::permanent) .chain( stream::iter(vec![final_response]).map(move |final_response| { if let Some(usage) = final_response.lock().unwrap().usage.as_ref() { let usage = usage.clone(); if let Some(callback) = maybe_usage_callback.as_ref() { let usage = usage.clone(); let callback = callback.clone(); tokio::spawn(async move { if let Err(e) = callback(&usage).await { tracing::error!("Error in on_usage callback: {}", e); } }); } #[cfg(feature = "metrics")] emit_usage( &model, usage.prompt_tokens.into(), usage.completion_tokens.into(), usage.total_tokens.into(), metric_metadata.as_ref(), ); } Ok(final_response.lock().unwrap().clone()) }), ) .boxed() } } #[allow(clippy::collapsible_match)] fn append_delta_from_chunk(chunk: &MessagesStreamEvent, lock: &mut ChatCompletionResponse) { match chunk { MessagesStreamEvent::ContentBlockStart { index, content_block, } => match content_block { MessageContent::ToolUse(tool_use) => { lock.append_tool_call_delta(*index, Some(&tool_use.id), Some(&tool_use.name), None); } MessageContent::Text(text) => { lock.append_message_delta(Some(&text.text)); } MessageContent::ToolResult(_tool_result) => (), }, MessagesStreamEvent::ContentBlockDelta { index, delta } => match delta { async_anthropic::types::ContentBlockDelta::TextDelta { text } => { lock.append_message_delta(Some(text)); } async_anthropic::types::ContentBlockDelta::InputJsonDelta { partial_json } => { lock.append_tool_call_delta(*index, None, None, Some(partial_json)); } }, #[allow(clippy::cast_possible_truncation)] MessagesStreamEvent::MessageDelta { usage, .. 
} => { if let Some(usage) = usage { let input_tokens = usage.input_tokens.unwrap_or_default(); let output_tokens = usage.output_tokens.unwrap_or_default(); let total_tokens = input_tokens + output_tokens; lock.append_usage_delta(input_tokens, output_tokens, total_tokens); } } MessagesStreamEvent::MessageStart { message, usage } => { if let Some(usage) = usage { let input_tokens = usage.input_tokens.unwrap_or_default(); let output_tokens = usage.output_tokens.unwrap_or_default(); let total_tokens = input_tokens + output_tokens; lock.append_usage_delta(input_tokens, output_tokens, total_tokens); } if let Some(message_usage) = &message.usage { let input_tokens = message_usage.input_tokens.unwrap_or_default(); let output_tokens = message_usage.output_tokens.unwrap_or_default(); let total_tokens = input_tokens + output_tokens; lock.append_usage_delta(input_tokens, output_tokens, total_tokens); } } _ => {} } } impl Anthropic { fn build_request( &self, request: &ChatCompletionRequest<'_>, ) -> Result<async_anthropic::types::CreateMessagesRequestBuilder, LanguageModelError> { let model = &self.default_options.prompt_model; let mut messages = request.messages().to_vec(); let maybe_system = messages .iter() .position(ChatMessage::is_system) .map(|idx| messages.remove(idx)); let messages = messages_to_antropic(&messages)?; let mut anthropic_request = CreateMessagesRequestBuilder::default() .model(model) .messages(messages) .to_owned(); if let Some(ChatMessage::System(system)) = maybe_system { anthropic_request.system(system); } if !request.tools_spec().is_empty() { anthropic_request .tools( request .tools_spec() .iter() .map(tools_to_anthropic) .collect::<Result<Vec<_>>>()?, ) .tool_choice(ToolChoice::Auto); } Ok(anthropic_request) } } fn messages_to_antropic(messages: &[ChatMessage]) -> Result<Vec<Message>> { let mut anthropic_messages = Vec::with_capacity(messages.len()); let mut messages = messages.iter().peekable(); while let Some(message) = messages.next() { match 
        message {
            ChatMessage::ToolOutput(tool_call, tool_output) => {
                // Greedily absorb every directly-following tool output into
                // the same user message.
                let mut content = vec![tool_result_to_anthropic(tool_call, tool_output)?];
                while let Some(ChatMessage::ToolOutput(tool_call, tool_output)) = messages.peek() {
                    content.push(tool_result_to_anthropic(tool_call, tool_output)?);
                    messages.next();
                }
                anthropic_messages.push(
                    MessageBuilder::default()
                        .role(MessageRole::User)
                        .content(MessageContentList(content))
                        .build()
                        .context("Failed to build message")?,
                );
            }
            _ => {
                // Non-tool-output messages convert one-to-one; `None` means
                // the message is dropped (e.g. empty assistant / reasoning).
                if let Some(message) = message_to_antropic(message)? {
                    anthropic_messages.push(message);
                }
            }
        }
    }
    Ok(anthropic_messages)
}

/// Converts one tool call + its output into an Anthropic tool-result content
/// block, keyed by the originating tool-use id. Falls back to the literal
/// string "Success" when the output carries no content.
fn tool_result_to_anthropic(
    tool_call: &ToolCall,
    tool_output: &ToolOutput,
) -> Result<MessageContent> {
    Ok(ToolResultBuilder::default()
        .tool_use_id(tool_call.id())
        .content(tool_output.content().unwrap_or("Success"))
        .build()?
        .into())
}

/// Converts a single swiftide `ChatMessage` into an Anthropic `Message`.
/// Returns `Ok(None)` for messages that have no Anthropic representation
/// (reasoning items, or an assistant message with neither text nor tool
/// calls).
#[allow(clippy::items_after_statements)]
fn message_to_antropic(message: &ChatMessage) -> Result<Option<Message>> {
    let mut builder = MessageBuilder::default().role(MessageRole::User).to_owned();
    match message {
        ChatMessage::ToolOutput(tool_call, tool_output) => builder.content(MessageContentList(
            vec![tool_result_to_anthropic(tool_call, tool_output)?],
        )),
        // Summaries and (stray) system messages are sent as plain user text.
        ChatMessage::Summary(msg) | ChatMessage::System(msg) => builder.content(msg.as_str()),
        ChatMessage::User(content) => builder.content(content.as_str()),
        ChatMessage::UserWithParts(parts) => {
            // Only text parts are supported here; any other modality is a
            // hard error rather than being silently dropped.
            if parts.iter().any(|part| {
                !matches!(
                    part,
                    swiftide_core::chat_completion::ChatMessageContentPart::Text { .. }
                )
            }) {
                anyhow::bail!("Anthropic chat completions only support text message parts");
            }
            let text_parts = parts
                .iter()
                .filter_map(|part| match part {
                    swiftide_core::chat_completion::ChatMessageContentPart::Text { text } => {
                        Some(text.as_ref())
                    }
                    swiftide_core::chat_completion::ChatMessageContentPart::Image { .. }
                    | swiftide_core::chat_completion::ChatMessageContentPart::Document { .. }
                    | swiftide_core::chat_completion::ChatMessageContentPart::Audio { .. }
                    | swiftide_core::chat_completion::ChatMessageContentPart::Video { .. } => None,
                })
                .collect::<Vec<_>>();
            builder.content(text_parts.join(" "))
        }
        ChatMessage::Assistant(content, tool_calls) => {
            builder.role(MessageRole::Assistant);
            let mut content_list: Vec<MessageContent> = Vec::new();
            if let Some(content) = content.as_ref() {
                content_list.push(content.clone().into());
            }
            if let Some(tool_calls) = tool_calls.as_ref() {
                for tool_call in tool_calls {
                    // NOTE(review): unparseable tool args are silently mapped
                    // to no input (`.ok()`); confirm that is intended rather
                    // than an error.
                    let tool_call = ToolUseBuilder::default()
                        .id(tool_call.id())
                        .name(tool_call.name())
                        .input(tool_call.args().and_then(|v| v.parse::<Value>().ok()))
                        .build()?;
                    content_list.push(tool_call.into());
                }
            }
            if content_list.is_empty() {
                return Ok(None);
            }
            let content_list = MessageContentList(content_list);
            builder.content(content_list)
        }
        ChatMessage::Reasoning(_) => return Ok(None),
    };
    builder.build().context("Failed to build message").map(Some)
}

/// Converts a swiftide `ToolSpec` into the JSON object shape Anthropic
/// expects for a tool: `name`, `description`, and an `input_schema` derived
/// from the spec's parameter schema.
fn tools_to_anthropic(
    spec: &ToolSpec,
) -> Result<serde_json::value::Map<String, serde_json::Value>> {
    let mut map = json!({
        "name": &spec.name,
        "description": &spec.description,
    })
    .as_object_mut()
    .context("Failed to build tool")?
    .to_owned();
    let schema = AnthropicToolSchema::try_from(spec)
        .context("tool schema must be Anthropic compatible")?
        .into_value();
    map.insert("input_schema".to_string(), schema);
    Ok(map)
}

#[cfg(test)]
mod tests {
    use super::*;
    use schemars::{JsonSchema, schema_for};
    use swiftide_core::{
        AgentContext, Tool,
        chat_completion::{ChatCompletionRequest, ChatMessage},
    };
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{body_partial_json, method, path},
    };

    // Two minimal tool stubs used to exercise request building; `invoke` is
    // never called in these tests (hence `todo!()`).
    #[derive(Clone)]
    struct FakeTool();
    #[derive(Clone)]
    struct AlphaTool();

    // Simple one-field argument schema shared by both stub tools.
    #[derive(JsonSchema, serde::Serialize, serde::Deserialize)]
    struct LocationArgs {
        location: String,
    }

    // Nested schema with only optional fields, used to verify the generated
    // input_schema keeps optional nested fields optional.
    #[derive(JsonSchema, serde::Serialize, serde::Deserialize)]
    #[serde(deny_unknown_fields)]
    struct NestedCommentArgs {
        request: NestedCommentRequest,
    }
    #[derive(JsonSchema, serde::Serialize, serde::Deserialize)]
    #[serde(deny_unknown_fields)]
    struct NestedCommentRequest {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        body: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        text: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        page_id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        block_id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        discussion_id: Option<String>,
    }

    #[async_trait]
    impl Tool for FakeTool {
        async fn invoke(
            &self,
            _agent_context: &dyn AgentContext,
            _tool_call: &ToolCall,
        ) -> std::result::Result<
            swiftide_core::chat_completion::ToolOutput,
            swiftide_core::chat_completion::errors::ToolError,
        > {
            todo!()
        }
        fn name(&self) -> std::borrow::Cow<'_, str> {
            "get_weather".into()
        }
        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::builder()
                .description("Gets the weather")
                .name("get_weather")
                .parameters_schema(schema_for!(LocationArgs))
                .build()
                .unwrap()
        }
    }

    #[async_trait]
    impl Tool for AlphaTool {
        async fn invoke(
            &self,
            _agent_context: &dyn AgentContext,
            _tool_call: &ToolCall,
        ) -> std::result::Result<
            swiftide_core::chat_completion::ToolOutput,
            swiftide_core::chat_completion::errors::ToolError,
        > {
            todo!()
        }
        fn name(&self) -> std::borrow::Cow<'_, str> {
            "alpha_tool".into()
        }
        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::builder()
                .name("alpha_tool")
                .description("Alpha tool")
                .parameters_schema(schemars::schema_for!(LocationArgs))
                .build()
                .unwrap()
        }
    }

    // End-to-end `complete()` against a wiremock Anthropic endpoint with no
    // tools: only the text message should come back.
    #[test_log::test(tokio::test)]
    async fn test_complete_without_tools() {
        // Start a wiremock server
        let mock_server = MockServer::start().await;
        // Create a mock response
        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "content": [{"type": "text", "text": "mocked response"}]
        }));
        // Mock the expected endpoint
        Mock::given(method("POST"))
            .and(path("/v1/messages")) // Adjust path to match expected endpoint
            .respond_with(mock_response)
            .mount(&mock_server)
            .await;
        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();
        // Build an Anthropic client with the mock server's URL
        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();
        // Prepare a sample request
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("hello".into())])
            .build()
            .unwrap();
        // Call the complete method
        let result = client.complete(&request).await.unwrap();
        // Assert the result
        assert_eq!(result.message, Some("mocked response".into()));
        assert!(result.tool_calls.is_none());
    }

    // End-to-end `complete()` where the mocked response contains both text
    // and a tool_use block; both must be surfaced on the completion.
    #[test_log::test(tokio::test)]
    async fn test_complete_with_tools() {
        // Start a wiremock server
        let mock_server = MockServer::start().await;
        // Create a mock response
        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "id": "msg_016zKNb88WhhgBQXhSaQf1rs",
            "content": [
                {
                    "type": "text",
                    "text": "I'll check the current weather in San Francisco, CA for you."
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01E1yxpxXU4hBgCMLzPL1FuR",
                    "input": { "location": "San Francisco, CA" },
                    "name": "get_weather"
                }
            ],
            "model": "claude-3-5-sonnet-20241022",
            "stop_reason": "tool_use",
            "stop_sequence": null,
            "usage": { "input_tokens": 403, "output_tokens": 71 }
        }));
        // Mock the expected endpoint
        Mock::given(method("POST"))
            .and(path("/v1/messages")) // Adjust path to match expected endpoint
            .respond_with(mock_response)
            .mount(&mock_server)
            .await;
        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();
        // Build an Anthropic client with the mock server's URL
        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();
        // Prepare a sample request
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("hello".into())])
            .tool_specs([FakeTool().tool_spec()])
            .build()
            .unwrap();
        // Call the complete method
        let result = client.complete(&request).await.unwrap();
        // Assert the result
        assert_eq!(
            result.message,
            Some("I'll check the current weather in San Francisco, CA for you.".into())
        );
        assert!(result.tool_calls.is_some());
        let Some(tool_call) = result.tool_calls.and_then(|f| f.first().cloned()) else {
            panic!("No tool call found")
        };
        assert_eq!(tool_call.name(), "get_weather");
        assert_eq!(
            tool_call.args(),
            Some(
                json!({"location": "San Francisco, CA"})
                    .to_string()
                    .as_str()
            )
        );
    }

    // Tools registered as [get_weather, alpha_tool] come back alphabetically,
    // so tool ordering in the built request is deterministic.
    #[test]
    fn test_build_request_orders_tools_deterministically() {
        let client = Anthropic::builder().build().unwrap();
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("hello".into())])
            .tool_specs([FakeTool().tool_spec(), AlphaTool().tool_spec()])
            .build()
            .unwrap();
        let built = client.build_request(&request).unwrap().build().unwrap();
        let tool_names = built
            .tools
            .expect("tools present")
            .into_iter()
            .map(|tool| {
                tool.get("name")
                    .and_then(serde_json::Value::as_str)
                    .expect("tool name")
                    .to_owned()
            })
            .collect::<Vec<_>>();
        assert_eq!(tool_names, vec!["alpha_tool", "get_weather"]);
    }

    // The system message must be sent as the dedicated `system` field, not as
    // a regular message; also checks usage totals are summed correctly.
    #[test_log::test(tokio::test)]
    async fn test_complete_with_system_prompt() {
        // Start a wiremock server
        let mock_server = MockServer::start().await;
        // Create a mock response
        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "content": [{"type": "text", "text": "Response with system prompt"}],
            "usage": { "input_tokens": 19, "output_tokens": 10, }
        }));
        // Mock the expected endpoint
        Mock::given(method("POST"))
            .and(path("/v1/messages")) // Adjust path to match expected endpoint
            .and(body_partial_json(json!({
                "system": "System message",
                "messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]
            })))
            .respond_with(mock_response)
            .mount(&mock_server)
            .await;
        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();
        // Build an Anthropic client with the mock server's URL
        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();
        // Prepare a sample request with a system message
        let request = ChatCompletionRequest::builder()
            .messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .build()
            .unwrap();
        // Call the complete method
        let response = client.complete(&request).await.unwrap();
        // Assert the result
        assert_eq!(response.message, Some("Response with system prompt".into()));
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

    // `tools_to_anthropic` must emit name/description plus the strict
    // parameter schema under `input_schema`.
    #[test]
    fn test_tools_to_anthropic() {
        let tool_spec = ToolSpec::builder()
            .description("Gets the weather")
            .name("get_weather")
            .parameters_schema(schema_for!(LocationArgs))
            .build()
            .unwrap();
        let result = tools_to_anthropic(&tool_spec).unwrap();
        let expected_schema = tool_spec.strict_parameters_schema().unwrap().into_json();
        let expected = json!({
            "name": "get_weather",
"description": "Gets the weather", "input_schema": expected_schema, }); assert_eq!(serde_json::Value::Object(result), expected); } #[test] fn test_tools_to_anthropic_preserves_optional_nested_fields() { let tool_spec = ToolSpec::builder() .description("Creates a comment") .name("create_comment") .parameters_schema(schema_for!(NestedCommentArgs)) .build() .unwrap(); let result = tools_to_anthropic(&tool_spec).unwrap(); let input_schema = result .get("input_schema") .and_then(Value::as_object) .expect("anthropic tool should contain input_schema"); assert_eq!( input_schema.get("type"), Some(&Value::String("object".into())) ); assert_eq!( input_schema.get("required"), Some(&Value::Array(vec![Value::String("request".into())])) ); let nested_ref = input_schema["properties"]["request"]["$ref"] .as_str() .expect("nested request should be referenced"); let nested_name = nested_ref .rsplit('/') .next() .expect("nested request ref name"); assert!(input_schema["$defs"][nested_name].get("required").is_none()); } #[test] fn test_build_request_groups_adjacent_tool_outputs() { let first_tool = ToolCall::builder() .id("tool_1") .name("shell_command") .args("{\"cmd\":\"pwd\"}") .build() .unwrap(); let second_tool = ToolCall::builder() .id("tool_2") .name("git") .args("{\"command\":\"status\"}") .build() .unwrap(); let request = ChatCompletionRequest::builder() .messages(vec![ ChatMessage::Assistant(None, Some(vec![first_tool.clone(), second_tool.clone()])), ChatMessage::new_tool_output( first_tool, ToolOutput::Text("pwd output".to_string()), ), ChatMessage::new_tool_output( second_tool, ToolOutput::Text("git output".to_string()), ), ]) .build() .unwrap(); let client = Anthropic::builder().build().unwrap(); let built = client.build_request(&request).unwrap().build().unwrap(); assert_eq!(built.messages.len(), 2); assert_eq!(built.messages[1].content.len(), 2); } } ================================================ FILE: swiftide-integrations/src/anthropic/mod.rs 
================================================

use std::{pin::Pin, sync::Arc};

use derive_builder::Builder;
use swiftide_core::chat_completion::Usage;

pub mod chat_completion;
pub mod simple_prompt;
mod tool_schema;

/// Client wrapper around `async_anthropic::Client` implementing swiftide's
/// completion/prompt traits. Construct via [`Anthropic::builder`].
#[derive(Builder, Clone)]
pub struct Anthropic {
    // Shared underlying HTTP client; defaults to a fresh default client.
    #[builder(
        default = Arc::new(async_anthropic::Client::default()),
        setter(custom)
    )]
    client: Arc<async_anthropic::Client>,
    // Model and related defaults used for every request.
    #[builder(default)]
    default_options: Options,
    #[cfg(feature = "metrics")]
    #[builder(default)]
    /// Optional metadata to attach to metrics emitted by this client.
    metric_metadata: Option<std::collections::HashMap<String, String>>,
    /// A callback function that is called when usage information is available.
    #[builder(default, setter(custom))]
    #[allow(clippy::type_complexity)]
    on_usage: Option<
        Arc<
            dyn for<'a> Fn(
                    &'a Usage,
                ) -> Pin<
                    Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>,
                > + Send
                + Sync,
        >,
    >,
}

// Manual Debug: the `on_usage` closure is not Debug, so only client and
// options are printed.
impl std::fmt::Debug for Anthropic {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Anthropic")
            .field("client", &self.client)
            .field("default_options", &self.default_options)
            .finish()
    }
}

/// Per-client defaults; currently only the prompt model.
#[derive(Debug, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
    #[builder(default)]
    pub prompt_model: String,
}

impl Default for Options {
    fn default() -> Self {
        Self {
            prompt_model: "claude-3-5-sonnet-20241022".to_string(),
        }
    }
}

impl Anthropic {
    /// Returns a fresh builder for configuring an `Anthropic` client.
    pub fn builder() -> AnthropicBuilder {
        AnthropicBuilder::default()
    }
}

impl AnthropicBuilder {
    /// Adds a callback function that will be called when usage information is available.
    pub fn on_usage<F>(&mut self, func: F) -> &mut Self
    where
        F: Fn(&Usage) -> anyhow::Result<()> + Send + Sync + 'static,
    {
        let func = Arc::new(func);
        // Wrap the sync callback in an async adapter so both sync and async
        // callbacks share one stored type.
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage) })
        })));
        self
    }

    /// Adds an asynchronous callback function that will be called when usage information is
    /// available.
    pub fn on_usage_async<F>(&mut self, func: F) -> &mut Self
    where
        F: for<'a> Fn(
                &'a Usage,
            )
                -> Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>>
            + Send
            + Sync
            + 'static,
    {
        let func = Arc::new(func);
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage).await })
        })));
        self
    }

    /// Sets the client for the `Anthropic` instance.
    ///
    /// See the `async_anthropic::Client` documentation for more information.
    ///
    /// # Parameters
    /// - `client`: The `Anthropic` client to set.
    ///
    /// # Returns
    /// A mutable reference to the `AnthropicBuilder`.
    pub fn client(&mut self, client: async_anthropic::Client) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

    /// Sets the default prompt model for the `Anthropic` instance.
    ///
    /// # Parameters
    /// - `model`: The prompt model to set.
    ///
    /// # Returns
    /// A mutable reference to the `AnthropicBuilder`.
    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        // Update options in place if already set, otherwise create them.
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = model.into();
        } else {
            self.default_options = Some(Options {
                prompt_model: model.into(),
            });
        }
        self
    }
}

================================================
FILE: swiftide-integrations/src/anthropic/simple_prompt.rs
================================================

use anyhow::Context as _;
use async_anthropic::{errors::AnthropicError, types::CreateMessagesRequestBuilder};
use async_trait::async_trait;
use swiftide_core::{
    chat_completion::{Usage, errors::LanguageModelError},
    indexing::SimplePrompt,
};

#[cfg(feature = "metrics")]
use swiftide_core::metrics::emit_usage;

use super::Anthropic;

#[async_trait]
impl SimplePrompt for Anthropic {
    /// Renders the prompt, sends it as a single user message, and returns the
    /// first text message of the response. Reports usage via the configured
    /// callback and (optionally) metrics.
    #[tracing::instrument(skip_all, err)]
    async fn prompt(
        &self,
        prompt: swiftide_core::prompt::Prompt,
    ) -> Result<String, LanguageModelError> {
        let model = &self.default_options.prompt_model;
        let request = CreateMessagesRequestBuilder::default()
            .model(model)
            .messages(vec![prompt.render()?.into()])
            .build()
            .map_err(LanguageModelError::permanent)?;
        tracing::debug!(
            model = &model,
            messages =
                serde_json::to_string_pretty(&request).map_err(LanguageModelError::permanent)?,
            "[SimplePrompt] Request to anthropic"
        );
        let response = self.client.messages().create(request).await.map_err(|e| {
            match &e {
                // Only network failures are considered retryable.
                AnthropicError::NetworkError(_) => LanguageModelError::TransientError(e.into()),
                // TODO: The Rust Anthropic client is not documented well, we should figure out
                // which of these errors are client errors and which are server errors.
                // And which would be the ContextLengthExceeded error
                // For now, we'll just map all of them to client errors so we get feedback.
                _ => LanguageModelError::PermanentError(e.into()),
            }
        })?;
        tracing::debug!(
            response =
                serde_json::to_string_pretty(&response).map_err(LanguageModelError::permanent)?,
            "[SimplePrompt] Response from anthropic"
        );
        // Report usage before extracting the message, so usage is tracked
        // even if the response turns out to have no text.
        if let Some(usage) = response.usage.as_ref() {
            let usage = Usage {
                prompt_tokens: usage.input_tokens.unwrap_or_default(),
                completion_tokens: usage.output_tokens.unwrap_or_default(),
                total_tokens: (usage.input_tokens.unwrap_or_default()
                    + usage.output_tokens.unwrap_or_default()),
                details: None,
            };
            if let Some(callback) = &self.on_usage {
                callback(&usage).await?;
            }
            #[cfg(feature = "metrics")]
            {
                emit_usage(
                    model,
                    usage.prompt_tokens.into(),
                    usage.completion_tokens.into(),
                    usage.total_tokens.into(),
                    self.metric_metadata.as_ref(),
                );
            }
        }
        let message = response
            .messages()
            .into_iter()
            .next()
            .context("No messages in response")
            .map_err(LanguageModelError::permanent)?;
        message
            .text()
            .context("No text in response")
            .map_err(LanguageModelError::permanent)
    }
}

#[cfg(test)]
mod tests {
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;

    // Round-trip `prompt()` against a wiremock Anthropic endpoint.
    #[tokio::test]
    async fn test_simple_prompt_with_mock() {
        // Start a WireMock server
        let mock_server = MockServer::start().await;
        // Create a mock response
        let mock_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "content": [{"type": "text", "text": "mocked response"}]
        }));
        // Mock the expected endpoint
        Mock::given(method("POST"))
            .and(path("/v1/messages")) // Adjust path to match expected endpoint
            .respond_with(mock_response)
            .mount(&mock_server)
            .await;
        let client = async_anthropic::Client::builder()
            .base_url(mock_server.uri())
            .build()
            .unwrap();
        // Build an Anthropic client with the mock server's URL
        let mut client_builder = Anthropic::builder();
        client_builder.client(client);
        let client = client_builder.build().unwrap();
        // Call the prompt method
        let result = client.prompt("hello".into()).await.unwrap();
        // Assert the result
        assert_eq!(result, "mocked response");
    }
}

================================================
FILE: swiftide-integrations/src/anthropic/tool_schema.rs
================================================

use serde_json::Value;
use swiftide_core::chat_completion::{ToolSpec, ToolSpecError};

/// Newtype over the canonical JSON parameter schema of a `ToolSpec`, in the
/// shape used for Anthropic's `input_schema` field.
pub(super) struct AnthropicToolSchema(Value);

impl AnthropicToolSchema {
    /// Unwraps the inner JSON value.
    pub(super) fn into_value(self) -> Value {
        self.0
    }
}

impl TryFrom<&ToolSpec> for AnthropicToolSchema {
    type Error = ToolSpecError;

    fn try_from(spec: &ToolSpec) -> Result<Self, Self::Error> {
        Ok(Self(spec.canonical_parameters_schema_json()?))
    }
}

================================================
FILE: swiftide-integrations/src/aws_bedrock_v2/chat_completion.rs
================================================

use std::collections::HashMap;

use anyhow::Context as _;
use async_trait::async_trait;
use aws_sdk_bedrockruntime::{
    operation::converse::ConverseOutput,
    types::{
        AudioBlock, AudioFormat, AudioSource, AutoToolChoice, ContentBlock, ContentBlockDelta,
        ContentBlockStart, ConversationRole, ConverseOutput as ConverseResult,
        ConverseStreamOutput, DocumentBlock, DocumentFormat, DocumentSource, ImageBlock,
        ImageFormat, ImageSource, InferenceConfiguration, Message, ReasoningContentBlock,
        ReasoningContentBlockDelta, ReasoningTextBlock,
S3Location, StopReason, SystemContentBlock, Tool, ToolChoice, ToolConfiguration, ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ToolSpecification, ToolUseBlock, VideoBlock, VideoFormat, VideoSource, }, }; use aws_smithy_json::{ deserialize::{json_token_iter, token::expect_document}, serialize::JsonValueWriter, }; use aws_smithy_types::{Blob, Document}; use base64::Engine as _; use futures_util::stream; #[cfg(feature = "langfuse")] use serde_json::json; use swiftide_core::{ ChatCompletion, ChatCompletionStream, chat_completion::{ ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ChatMessageContentPart, ChatMessageContentSource, ReasoningItem, ToolCall, ToolOutput, ToolSpec, errors::LanguageModelError, }, }; use tracing_futures::Instrument; use super::tool_schema::AwsBedrockToolSchema; use super::{AwsBedrock, Options}; type ConverseInputParts = ( Vec<Message>, Option<Vec<SystemContentBlock>>, Option<InferenceConfiguration>, Option<ToolConfiguration>, ); type ExtractedMessage = (Option<String>, Option<Vec<ToolCall>>, Vec<ReasoningItem>); #[async_trait] impl ChatCompletion for AwsBedrock { #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))] #[cfg_attr( feature = "langfuse", tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION")) )] async fn complete( &self, request: &ChatCompletionRequest<'_>, ) -> Result<ChatCompletionResponse, LanguageModelError> { let model = self.prompt_model()?; #[cfg(feature = "langfuse")] let tracking_request = Some(json!({ "model": model, "messages": request.messages(), "tools_spec": request.tools_spec(), })); #[cfg(not(feature = "langfuse"))] let tracking_request: Option<serde_json::Value> = None; let (messages, system, inference_config, tool_config) = match build_converse_input(request, &self.default_options) { Ok(parts) => parts, Err(error) => { Self::track_failure( model, tracking_request.as_ref(), None::<&serde_json::Value>, &error, ); return Err(error); } }; 
        let additional_model_request_fields = match super::additional_model_request_fields_from_options(model, &self.default_options)
        {
            Ok(fields) => fields,
            Err(error) => {
                Self::track_failure(
                    model,
                    tracking_request.as_ref(),
                    None::<&serde_json::Value>,
                    &error,
                );
                return Err(error);
            }
        };
        tracing::debug!(
            model = model,
            inference_config = ?inference_config,
            has_tool_config = tool_config.is_some(),
            "[ChatCompletion] Request to bedrock converse"
        );
        let response = match self
            .client
            .converse(
                model,
                messages,
                system,
                inference_config,
                tool_config,
                None,
                additional_model_request_fields,
                self.default_options
                    .additional_model_response_field_paths
                    .clone(),
            )
            .await
        {
            Ok(response) => response,
            Err(error) => {
                Self::track_failure(
                    model,
                    tracking_request.as_ref(),
                    None::<&serde_json::Value>,
                    &error,
                );
                return Err(error);
            }
        };
        tracing::debug!(response = ?response, "[ChatCompletion] Response from bedrock converse");
        let completion = match response_to_chat_completion(&response) {
            Ok(completion) => completion,
            Err(error) => {
                Self::track_failure(
                    model,
                    tracking_request.as_ref(),
                    None::<&serde_json::Value>,
                    &error,
                );
                return Err(error);
            }
        };
        // A completion with no message, no tool calls and no reasoning is
        // treated as a context-length failure (see helper in the parent mod).
        if let Some(error) = super::context_length_exceeded_if_empty(
            completion.message.is_some(),
            completion.tool_calls.is_some(),
            completion
                .reasoning
                .as_ref()
                .is_some_and(|reasoning| !reasoning.is_empty()),
            Some(response.stop_reason()),
        ) {
            Self::track_failure(model, tracking_request.as_ref(), Some(&completion), &error);
            return Err(error);
        }
        self.track_completion(
            model,
            completion.usage.as_ref(),
            tracking_request.as_ref(),
            Some(&completion),
        )
        .await?;
        Ok(completion)
    }

    /// Streaming completion via Bedrock ConverseStream. Emits the accumulated
    /// response after every event; the final emission (on stream end) also
    /// runs completion tracking. Errors during construction are returned as a
    /// one-item error stream.
    #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all))]
    #[cfg_attr(
        feature = "langfuse",
        tracing::instrument(skip_all, fields(langfuse.type = "GENERATION"))
    )]
    async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream {
        let model = match self.prompt_model() {
            Ok(model) => model.to_string(),
            Err(error) => return error.into(),
        };
        #[cfg(feature = "langfuse")]
        let tracking_request = Some(json!({
            "model": model,
            "messages": request.messages(),
            "tools_spec": request.tools_spec(),
        }));
        #[cfg(not(feature = "langfuse"))]
        let tracking_request: Option<serde_json::Value> = None;
        let (messages, system, inference_config, tool_config) =
            match build_converse_input(request, &self.default_options) {
                Ok(parts) => parts,
                Err(error) => {
                    Self::track_failure(
                        &model,
                        tracking_request.as_ref(),
                        None::<&serde_json::Value>,
                        &error,
                    );
                    return error.into();
                }
            };
        let additional_model_request_fields = match super::additional_model_request_fields_from_options(&model, &self.default_options)
        {
            Ok(fields) => fields,
            Err(error) => {
                Self::track_failure(
                    &model,
                    tracking_request.as_ref(),
                    None::<&serde_json::Value>,
                    &error,
                );
                return error.into();
            }
        };
        let stream_output = match self
            .client
            .converse_stream(
                &model,
                messages,
                system,
                inference_config,
                tool_config,
                additional_model_request_fields,
                self.default_options
                    .additional_model_response_field_paths
                    .clone(),
            )
            .await
        {
            Ok(stream_output) => stream_output,
            Err(error) => {
                Self::track_failure(
                    &model,
                    tracking_request.as_ref(),
                    None::<&serde_json::Value>,
                    &error,
                );
                return error.into();
            }
        };
        let self_for_stream = self.clone();
        let event_stream = stream_output.stream;
        // unfold state: (event stream, accumulated response, last stop
        // reason, finished flag, model name, tracking payload). The
        // `finished` flag guarantees exactly one terminal item.
        let stream = stream::unfold(
            (
                event_stream,
                ChatCompletionResponse::default(),
                None::<StopReason>,
                false,
                model,
                tracking_request,
            ),
            move |(
                mut event_stream,
                mut response,
                mut stop_reason,
                finished,
                model,
                tracking_request,
            )| {
                let self_for_stream = self_for_stream.clone();
                async move {
                    if finished {
                        return None;
                    }
                    match event_stream.recv().await {
                        // Normal event: fold it in and emit a snapshot.
                        Ok(Some(event)) => {
                            apply_stream_event(&event, &mut response, &mut stop_reason);
                            Some((
                                Ok(response.clone()),
                                (
                                    event_stream,
                                    response,
                                    stop_reason,
                                    false,
                                    model,
                                    tracking_request,
                                ),
                            ))
                        }
                        // Stream ended: validate non-emptiness, run tracking,
                        // then emit the final snapshot and mark finished.
                        Ok(None) => {
                            if let Some(error) = super::context_length_exceeded_if_empty(
                                response.message.is_some(),
                                response.tool_calls.is_some(),
                                response
                                    .reasoning
                                    .as_ref()
                                    .is_some_and(|reasoning| !reasoning.is_empty()),
                                stop_reason.as_ref(),
                            ) {
                                Self::track_failure(
                                    &model,
                                    tracking_request.as_ref(),
                                    Some(&response),
                                    &error,
                                );
                                return Some((
                                    Err(error),
                                    (
                                        event_stream,
                                        response,
                                        stop_reason,
                                        true,
                                        model,
                                        tracking_request,
                                    ),
                                ));
                            }
                            if let Err(error) = self_for_stream
                                .track_completion(
                                    &model,
                                    response.usage.as_ref(),
                                    tracking_request.as_ref(),
                                    Some(&response),
                                )
                                .await
                            {
                                return Some((
                                    Err(error),
                                    (
                                        event_stream,
                                        response,
                                        stop_reason,
                                        true,
                                        model,
                                        tracking_request,
                                    ),
                                ));
                            }
                            Some((
                                Ok(response.clone()),
                                (
                                    event_stream,
                                    response,
                                    stop_reason,
                                    true,
                                    model,
                                    tracking_request,
                                ),
                            ))
                        }
                        // Transport error: convert, track, emit once, finish.
                        Err(error) => {
                            let error =
                                super::converse_stream_output_error_to_language_model_error(error);
                            Self::track_failure(
                                &model,
                                tracking_request.as_ref(),
                                Some(&response),
                                &error,
                            );
                            Some((
                                Err(error),
                                (
                                    event_stream,
                                    response,
                                    stop_reason,
                                    true,
                                    model,
                                    tracking_request,
                                ),
                            ))
                        }
                    }
                }
            },
        );
        let span = if cfg!(feature = "langfuse") {
            tracing::info_span!("stream", langfuse.type = "GENERATION")
        } else {
            tracing::info_span!("stream")
        };
        Box::pin(Instrument::instrument(stream, span))
    }
}

/// Converts a swiftide request into the pieces of a Bedrock Converse call:
/// system messages are collected separately, adjacent tool outputs are
/// grouped into one user message, and empty assistant turns are skipped.
fn build_converse_input(
    request: &ChatCompletionRequest<'_>,
    options: &Options,
) -> Result<ConverseInputParts, LanguageModelError> {
    let source_messages = request.messages();
    let mut messages = Vec::with_capacity(source_messages.len());
    let mut system = Vec::new();
    let mut source_messages = source_messages.iter().peekable();
    while let Some(message) = source_messages.next() {
        match message {
            ChatMessage::System(text) => {
                system.push(SystemContentBlock::Text(text.clone()));
            }
            ChatMessage::Summary(text) | ChatMessage::User(text) => {
                messages.push(user_message_from_text(text.clone())?);
            }
            ChatMessage::UserWithParts(parts) => messages.push(user_message_from_parts(parts)?),
            ChatMessage::Assistant(content, maybe_tool_calls) => {
                // Pre-size: one slot for non-empty text plus one per tool call.
                let mut blocks = Vec::with_capacity(
                    usize::from(content.as_ref().is_some_and(|text| !text.is_empty()))
                        + maybe_tool_calls.as_ref().map_or(0, Vec::len),
                );
                if
let Some(content) = content.as_ref() && !content.is_empty() { blocks.push(ContentBlock::Text(content.clone())); } if let Some(tool_calls) = maybe_tool_calls.as_ref() { for tool_call in tool_calls { let input = tool_call_args_to_document(tool_call.args()).with_context(|| { format!("Invalid JSON args for tool call {}", tool_call.name()) })?; let tool_use = ToolUseBlock::builder() .tool_use_id(tool_call.id()) .name(tool_call.name()) .input(input) .build() .map_err(LanguageModelError::permanent)?; blocks.push(ContentBlock::ToolUse(tool_use)); } } if !blocks.is_empty() { messages.push(message_from_blocks(ConversationRole::Assistant, blocks)?); } } ChatMessage::ToolOutput(tool_call, output) => { let mut blocks = vec![tool_output_to_result_block(tool_call, output)?]; while let Some(ChatMessage::ToolOutput(tool_call, output)) = source_messages.peek() { blocks.push(tool_output_to_result_block(tool_call, output)?); source_messages.next(); } messages.push(message_from_blocks(ConversationRole::User, blocks)?); } ChatMessage::Reasoning(item) => { if let Some(reasoning_message) = assistant_reasoning_message_from_item(item)? 
{ messages.push(reasoning_message); } } } } if messages.is_empty() { return Err(LanguageModelError::permanent( "Bedrock Converse requires at least one non-system message", )); } Ok(( messages, (!system.is_empty()).then_some(system), super::inference_config_from_options(options), tool_config_from_specs(request.tools_spec().iter(), options.tool_strict_enabled())?, )) } fn user_message_from_text(text: String) -> Result<Message, LanguageModelError> { message_from_blocks(ConversationRole::User, vec![ContentBlock::Text(text)]) } fn user_message_from_parts( parts: &[ChatMessageContentPart], ) -> Result<Message, LanguageModelError> { let mut blocks = Vec::with_capacity(parts.len()); let mut has_text = false; let mut has_document = false; for part in parts { match part { ChatMessageContentPart::Text { text } => { if !text.is_empty() { blocks.push(ContentBlock::Text(text.clone())); has_text = true; } } ChatMessageContentPart::Image { source, format } => { blocks.push(ContentBlock::Image(image_block_from_part( source, format.as_deref(), )?)); } ChatMessageContentPart::Document { source, format, name, } => { blocks.push(ContentBlock::Document(document_block_from_part( source, format.as_deref(), name.as_deref(), )?)); has_document = true; } ChatMessageContentPart::Audio { source, format } => { blocks.push(ContentBlock::Audio(audio_block_from_part( source, format.as_deref(), )?)); } ChatMessageContentPart::Video { source, format } => { blocks.push(ContentBlock::Video(video_block_from_part( source, format.as_deref(), )?)); } } } if blocks.is_empty() { return Err(LanguageModelError::permanent( "UserWithParts requires at least one content part", )); } if has_document && !has_text { return Err(LanguageModelError::permanent( "Bedrock document parts require at least one text part in the same message", )); } message_from_blocks(ConversationRole::User, blocks) } fn image_block_from_part( source: &ChatMessageContentSource, format: Option<&str>, ) -> Result<ImageBlock, LanguageModelError> 
{
    // Body of `image_block_from_part` (signature directly above): resolve the
    // format, convert the source, assemble the SDK block.
    let format = image_format_from_source(format, source)?;
    let source = image_source_from_content_source(source)?;
    ImageBlock::builder()
        .format(format)
        .source(source)
        .build()
        .map_err(LanguageModelError::permanent)
}

/// Builds a Bedrock document block; the name falls back to "document" when
/// the caller does not supply one.
fn document_block_from_part(
    source: &ChatMessageContentSource,
    format: Option<&str>,
    name: Option<&str>,
) -> Result<DocumentBlock, LanguageModelError> {
    let format = document_format_from_source(format, source)?;
    let source = document_source_from_content_source(source)?;
    let name = name.unwrap_or("document");
    DocumentBlock::builder()
        .format(format)
        .name(name)
        .source(source)
        .build()
        .map_err(LanguageModelError::permanent)
}

/// Builds a Bedrock audio block, inferring the format when not explicit.
fn audio_block_from_part(
    source: &ChatMessageContentSource,
    format: Option<&str>,
) -> Result<AudioBlock, LanguageModelError> {
    let format = audio_format_from_source(format, source)?;
    let source = audio_source_from_content_source(source)?;
    AudioBlock::builder()
        .format(format)
        .source(source)
        .build()
        .map_err(LanguageModelError::permanent)
}

/// Builds a Bedrock video block, inferring the format when not explicit.
fn video_block_from_part(
    source: &ChatMessageContentSource,
    format: Option<&str>,
) -> Result<VideoBlock, LanguageModelError> {
    let format = video_format_from_source(format, source)?;
    let source = video_source_from_content_source(source)?;
    VideoBlock::builder()
        .format(format)
        .source(source)
        .build()
        .map_err(LanguageModelError::permanent)
}

/// Maps a generic content source onto the SDK's `ImageSource`.
fn image_source_from_content_source(
    source: &ChatMessageContentSource,
) -> Result<ImageSource, LanguageModelError> {
    source_from_content_source(source, "image", ImageSource::Bytes, ImageSource::S3Location)
}

/// Maps a generic content source onto the SDK's `DocumentSource`.
fn document_source_from_content_source(
    source: &ChatMessageContentSource,
) -> Result<DocumentSource, LanguageModelError> {
    source_from_content_source(
        source,
        "document",
        DocumentSource::Bytes,
        DocumentSource::S3Location,
    )
}

/// Maps a generic content source onto the SDK's `AudioSource`.
fn audio_source_from_content_source(
    source: &ChatMessageContentSource,
) -> Result<AudioSource, LanguageModelError> {
    source_from_content_source(source, "audio", AudioSource::Bytes, AudioSource::S3Location)
}

/// Maps a generic content source onto the SDK's `VideoSource`.
fn video_source_from_content_source(
    source: &ChatMessageContentSource,
) -> Result<VideoSource, LanguageModelError> {
    source_from_content_source(source, "video", VideoSource::Bytes, VideoSource::S3Location)
}

/// Shared conversion from [`ChatMessageContentSource`] to an SDK source type.
///
/// Raw bytes become a `Blob`; S3 sources and `s3://` URLs become an
/// `S3Location`; base64 `data:` URLs are decoded into bytes. Other URLs and
/// `file_id` sources are rejected with a permanent error.
fn source_from_content_source<T>(
    source: &ChatMessageContentSource,
    label: &str,
    from_bytes: impl Fn(Blob) -> T,
    from_s3: impl Fn(S3Location) -> T,
) -> Result<T, LanguageModelError> {
    match source {
        ChatMessageContentSource::Bytes { data, .. } => Ok(from_bytes(Blob::new(data.clone()))),
        ChatMessageContentSource::S3 { uri, bucket_owner } => {
            Ok(from_s3(s3_location(uri, bucket_owner.as_deref())?))
        }
        ChatMessageContentSource::Url { url } => {
            if is_s3_url(url) {
                Ok(from_s3(s3_location(url, None)?))
            } else if let Some((_, encoded)) = parse_data_url(url) {
                Ok(from_bytes(Blob::new(decode_data_url_bytes(encoded)?)))
            } else {
                Err(LanguageModelError::permanent(format!(
                    "Bedrock {label} source URL must be data: or s3://"
                )))
            }
        }
        ChatMessageContentSource::FileId { .. } => Err(LanguageModelError::permanent(format!(
            "Bedrock does not support file_id {label} sources"
        ))),
    }
}

/// Resolves the image format: explicit value wins, otherwise inferred.
fn image_format_from_source(
    format: Option<&str>,
    source: &ChatMessageContentSource,
) -> Result<ImageFormat, LanguageModelError> {
    resolve_format(
        format,
        source,
        infer_image_format_from_source,
        |value| ImageFormat::try_parse(value).ok(),
        "image",
    )
}

/// Resolves the document format: explicit value wins, otherwise inferred.
fn document_format_from_source(
    format: Option<&str>,
    source: &ChatMessageContentSource,
) -> Result<DocumentFormat, LanguageModelError> {
    resolve_format(
        format,
        source,
        infer_document_format_from_source,
        |value| DocumentFormat::try_parse(value).ok(),
        "document",
    )
}

/// Resolves the audio format: explicit value wins, otherwise inferred.
fn audio_format_from_source(
    format: Option<&str>,
    source: &ChatMessageContentSource,
) -> Result<AudioFormat, LanguageModelError> {
    resolve_format(
        format,
        source,
        infer_audio_format_from_source,
        |value| AudioFormat::try_parse(value).ok(),
        "audio",
    )
}

/// Resolves the video format: explicit value wins, otherwise inferred.
fn video_format_from_source(
    format: Option<&str>,
    source: &ChatMessageContentSource,
) -> Result<VideoFormat, LanguageModelError> {
resolve_format( format, source, infer_video_format_from_source, |value| VideoFormat::try_parse(value).ok(), "video", ) } fn resolve_format<T>( explicit_format: Option<&str>, source: &ChatMessageContentSource, infer: impl Fn(&ChatMessageContentSource) -> Option<&'static str>, parse: impl Fn(&str) -> Option<T>, label: &str, ) -> Result<T, LanguageModelError> { let value = explicit_format.or_else(|| infer(source)).ok_or_else(|| { LanguageModelError::permanent(format!("Bedrock {label} format is required")) })?; parse(value).ok_or_else(|| { LanguageModelError::permanent(format!("Unsupported Bedrock {label} format: {value}")) }) } fn infer_image_format_from_source(source: &ChatMessageContentSource) -> Option<&'static str> { infer_format_from_source( source, IMAGE_MEDIA_TYPE_FORMATS, IMAGE_EXTENSION_FORMATS, None, ) } fn infer_document_format_from_source(source: &ChatMessageContentSource) -> Option<&'static str> { infer_format_from_source( source, DOCUMENT_MEDIA_TYPE_FORMATS, DOCUMENT_EXTENSION_FORMATS, Some("txt"), ) } fn infer_audio_format_from_source(source: &ChatMessageContentSource) -> Option<&'static str> { infer_format_from_source( source, AUDIO_MEDIA_TYPE_FORMATS, AUDIO_EXTENSION_FORMATS, None, ) } fn infer_video_format_from_source(source: &ChatMessageContentSource) -> Option<&'static str> { infer_format_from_source( source, VIDEO_MEDIA_TYPE_FORMATS, VIDEO_EXTENSION_FORMATS, None, ) } fn infer_format_from_source( source: &ChatMessageContentSource, media_type_mappings: &[(&'static str, &'static str)], extension_mappings: &[(&'static str, &'static str)], fallback: Option<&'static str>, ) -> Option<&'static str> { match source { ChatMessageContentSource::Bytes { media_type, .. 
} => media_type .as_deref() .and_then(|media_type| mapped_format(media_type, media_type_mappings)) .or(fallback), ChatMessageContentSource::Url { url } => if let Some((media_type, _)) = parse_data_url(url) { mapped_format(media_type, media_type_mappings) } else { extension_from_url(url) .and_then(|extension| mapped_format(extension, extension_mappings)) } .or(fallback), ChatMessageContentSource::S3 { uri, .. } => extension_from_url(uri) .and_then(|extension| mapped_format(extension, extension_mappings)) .or(fallback), ChatMessageContentSource::FileId { .. } => fallback, } } fn s3_location(uri: &str, bucket_owner: Option<&str>) -> Result<S3Location, LanguageModelError> { let mut builder = S3Location::builder().uri(uri); if let Some(bucket_owner) = bucket_owner { builder = builder.bucket_owner(bucket_owner); } builder.build().map_err(LanguageModelError::permanent) } fn is_s3_url(url: &str) -> bool { url.starts_with("s3://") } fn parse_data_url(url: &str) -> Option<(&str, &str)> { let rest = url.strip_prefix("data:")?; let (header, data) = rest.split_once(',')?; let media_type = header.strip_suffix(";base64")?; Some((media_type, data)) } fn decode_data_url_bytes(encoded: &str) -> Result<Vec<u8>, LanguageModelError> { base64::engine::general_purpose::STANDARD .decode(encoded) .map_err(LanguageModelError::permanent) } fn extension_from_url(url: &str) -> Option<&str> { let without_query = url.split(['?', '#']).next()?; let filename = without_query.rsplit('/').next()?; let (_, extension) = filename.rsplit_once('.')?; Some(extension) } fn mapped_format(value: &str, mappings: &[(&'static str, &'static str)]) -> Option<&'static str> { mappings .iter() .find_map(|(input, output)| input.eq_ignore_ascii_case(value).then_some(*output)) } fn message_from_blocks( role: ConversationRole, blocks: Vec<ContentBlock>, ) -> Result<Message, LanguageModelError> { Message::builder() .role(role) .set_content(Some(blocks)) .build() .map_err(LanguageModelError::permanent) } fn 
tool_output_to_result_block(
    tool_call: &ToolCall,
    output: &ToolOutput,
) -> Result<ContentBlock, LanguageModelError> {
    // Failed tool runs are flagged with an error status; everything else as success.
    let status = match output {
        ToolOutput::Fail(_) => Some(ToolResultStatus::Error),
        _ => Some(ToolResultStatus::Success),
    };
    let tool_result = ToolResultBlock::builder()
        .tool_use_id(tool_call.id())
        .content(tool_output_to_content_block(output)?)
        .set_status(status)
        .build()
        .map_err(LanguageModelError::permanent)?;
    Ok(ContentBlock::ToolResult(tool_result))
}

/// Chooses the Bedrock content representation for a tool output: text for
/// plain/failed output, JSON for structured control-flow payloads, and the
/// `Display` form for everything else.
fn tool_output_to_content_block(
    output: &ToolOutput,
) -> Result<ToolResultContentBlock, LanguageModelError> {
    match output {
        ToolOutput::Text(text) | ToolOutput::Fail(text) => {
            Ok(ToolResultContentBlock::Text(text.clone()))
        }
        ToolOutput::FeedbackRequired(Some(value))
        | ToolOutput::Stop(Some(value))
        | ToolOutput::AgentFailed(Some(value)) => {
            Ok(ToolResultContentBlock::Json(json_value_to_document(value)?))
        }
        _ => Ok(ToolResultContentBlock::Text(output.to_string())),
    }
}

/// Parses a tool call's JSON argument string into an SDK `Document`.
/// Missing or blank args become an empty JSON object.
fn tool_call_args_to_document(args: Option<&str>) -> Result<Document, LanguageModelError> {
    match args.map(str::trim) {
        Some(args) if !args.is_empty() => parse_document_json_bytes(args.as_bytes())
            .with_context(|| format!("Failed to parse tool args as JSON: {args}"))
            .map_err(LanguageModelError::permanent),
        _ => Ok(Document::Object(HashMap::new())),
    }
}

/// Converts tool specs into a Bedrock `ToolConfiguration` with automatic
/// tool choice; returns `None` when no tools are given.
fn tool_config_from_specs<'a>(
    tool_specs: impl IntoIterator<Item = &'a ToolSpec>,
    strict: bool,
) -> Result<Option<ToolConfiguration>, LanguageModelError> {
    let tools = tool_specs
        .into_iter()
        .map(|spec| tool_spec_to_bedrock(spec, strict))
        .collect::<Result<Vec<_>, _>>()?;
    if tools.is_empty() {
        return Ok(None);
    }
    let tool_config = ToolConfiguration::builder()
        .set_tools(Some(tools))
        .tool_choice(ToolChoice::Auto(AutoToolChoice::builder().build()))
        .build()
        .map_err(LanguageModelError::permanent)?;
    Ok(Some(tool_config))
}

/// Converts one swiftide tool spec into a Bedrock tool definition, carrying
/// the schema over as a JSON input schema and honoring the strict flag.
fn tool_spec_to_bedrock(spec: &ToolSpec, strict: bool) -> Result<Tool, LanguageModelError> {
    let schema_value = AwsBedrockToolSchema::try_from(spec)
        .map(AwsBedrockToolSchema::into_value)
        .map_err(LanguageModelError::permanent)?;
    let input_schema = ToolInputSchema::Json(json_value_to_document(&schema_value)?);
    let mut builder = ToolSpecification::builder()
        .name(spec.name.clone())
        .input_schema(input_schema)
        .strict(strict);
    // Only attach a description when one is present.
    if !spec.description.is_empty() {
        builder = builder.description(spec.description.clone());
    }
    let tool_spec = builder.build().map_err(LanguageModelError::permanent)?;
    Ok(Tool::ToolSpec(tool_spec))
}

/// Maps a non-streaming Converse response onto swiftide's
/// [`ChatCompletionResponse`], including tool calls, reasoning, and usage.
pub(super) fn response_to_chat_completion(
    response: &ConverseOutput,
) -> Result<ChatCompletionResponse, LanguageModelError> {
    let (message, tool_calls, reasoning) =
        if let Some(ConverseResult::Message(message)) = response.output() {
            extract_message_and_tool_calls(message)?
        } else {
            // No message in the output: everything stays empty.
            (None, None, Vec::new())
        };
    let mut builder = ChatCompletionResponse::builder()
        .maybe_message(message)
        .maybe_tool_calls(tool_calls)
        .to_owned();
    if !reasoning.is_empty() {
        builder.reasoning(reasoning);
    }
    if let Some(usage) = response.usage() {
        builder.usage(super::usage_from_bedrock(usage));
    }
    builder.build().map_err(LanguageModelError::from)
}

/// Splits a Bedrock assistant message into concatenated text, tool calls,
/// and reasoning items (keyed by content block index).
fn extract_message_and_tool_calls(
    message: &Message,
) -> Result<ExtractedMessage, LanguageModelError> {
    let mut text = String::new();
    let mut has_text = false;
    let mut tool_calls = Vec::with_capacity(message.content().len());
    let mut reasoning = Vec::new();
    for (content_block_index, block) in message.content().iter().enumerate() {
        match block {
            ContentBlock::Text(block_text) => {
                // Multiple text blocks are concatenated into one message string.
                text.push_str(block_text);
                has_text = true;
            }
            ContentBlock::ToolUse(tool_use) => {
                let args = document_to_json_string(tool_use.input());
                let tool_call = ToolCall::builder()
                    .id(tool_use.tool_use_id())
                    .name(tool_use.name())
                    .args(args)
                    .build()
                    .map_err(LanguageModelError::permanent)?;
                tool_calls.push(tool_call);
            }
            ContentBlock::ReasoningContent(ReasoningContentBlock::ReasoningText(
                reasoning_text,
            )) => {
                reasoning.push(reasoning_item_from_reasoning_text(
                    content_block_index,
                    reasoning_text.text(),
                    reasoning_text.signature(),
                ));
            }
            // Other block kinds are ignored when extracting the response.
            _ => {}
        }
    }
    // Distinguish "no text block at all" (None) from an empty concatenation.
    let message = has_text.then_some(text);
    let tool_calls = (!tool_calls.is_empty()).then_some(tool_calls);
    Ok((message, tool_calls, reasoning))
}

/// Serializes an SDK `Document` into a JSON string.
fn document_to_json_string(document: &Document) -> String {
    let mut output = String::new();
    JsonValueWriter::new(&mut output).document(document);
    output
}

/// Folds one streaming Converse event into the accumulating response and
/// records the stop reason when the message ends.
fn apply_stream_event(
    event: &ConverseStreamOutput,
    response: &mut ChatCompletionResponse,
    stop_reason: &mut Option<StopReason>,
) {
    match event {
        ConverseStreamOutput::ContentBlockStart(event) => {
            // A new tool-use block: register its id and name under this index.
            if let (Some(ContentBlockStart::ToolUse(tool_use)), Ok(index)) =
                (event.start(), usize::try_from(event.content_block_index()))
            {
                response.append_tool_call_delta(
                    index,
                    Some(tool_use.tool_use_id()),
                    Some(tool_use.name()),
                    None,
                );
            }
        }
        ConverseStreamOutput::ContentBlockDelta(event) => {
            // Deltas without a usable index or payload are silently dropped.
            let Ok(index) = usize::try_from(event.content_block_index()) else {
                return;
            };
            let Some(delta) = event.delta() else {
                return;
            };
            match delta {
                ContentBlockDelta::Text(text) => {
                    response.append_message_delta(Some(text));
                }
                ContentBlockDelta::ToolUse(delta) => {
                    response.append_tool_call_delta(index, None, None, Some(delta.input()));
                }
                ContentBlockDelta::ReasoningContent(delta) => {
                    apply_reasoning_delta(response, index, delta);
                }
                _ => {}
            }
        }
        ConverseStreamOutput::MessageStop(event) => {
            *stop_reason = Some(event.stop_reason().clone());
        }
        ConverseStreamOutput::Metadata(event) => {
            if let Some(usage) = event.usage() {
                response.usage = Some(super::usage_from_bedrock(usage));
            }
        }
        _ => {}
    }
}

/// Rebuilds an assistant reasoning message from a stored [`ReasoningItem`].
///
/// Returns `Ok(None)` when either the reasoning text or its signature is
/// missing or empty — presumably both are required for replay; confirm
/// against the Converse reasoning-content contract.
fn assistant_reasoning_message_from_item(
    item: &ReasoningItem,
) -> Result<Option<Message>, LanguageModelError> {
    let text = item
        .content
        .as_ref()
        .and_then(|content| content.first())
        .map(String::as_str)
        .filter(|text| !text.is_empty());
    let signature = item
        .encrypted_content
        .as_deref()
        .filter(|value| !value.is_empty());
    let (Some(text), Some(signature)) = (text, signature) else {
        return Ok(None);
    };
    let reasoning_text_block = ReasoningTextBlock::builder()
        .text(text)
        .signature(signature)
        .build()
        .map_err(LanguageModelError::permanent)?;
    message_from_blocks(
        ConversationRole::Assistant,
        vec![ContentBlock::ReasoningContent(
            ReasoningContentBlock::ReasoningText(reasoning_text_block),
        )],
    )
    .map(Some)
}

/// Creates a [`ReasoningItem`] from a reasoning text block, keyed by the
/// content block index so streaming deltas can locate it again.
fn reasoning_item_from_reasoning_text(
    content_block_index: usize,
    text: &str,
    signature: Option<&str>,
) -> ReasoningItem {
    ReasoningItem {
        id: format!("bedrock_reasoning_{content_block_index}"),
        summary: Vec::new(),
        content: Some(vec![text.to_string()]),
        encrypted_content: signature.map(ToString::to_string),
        status: None,
    }
}

/// Applies one reasoning stream delta: text deltas append to the first
/// content entry, signature deltas overwrite the encrypted content.
fn apply_reasoning_delta(
    response: &mut ChatCompletionResponse,
    content_block_index: usize,
    delta: &ReasoningContentBlockDelta,
) {
    let reasoning_item = ensure_reasoning_item(response, content_block_index);
    match delta {
        ReasoningContentBlockDelta::Text(text) => {
            let content = reasoning_item
                .content
                .get_or_insert_with(|| vec![String::new()]);
            // Guard against an existing-but-empty vec before indexing [0].
            if content.is_empty() {
                content.push(String::new());
            }
            content[0].push_str(text);
        }
        ReasoningContentBlockDelta::Signature(signature) => {
            reasoning_item.encrypted_content = Some(signature.clone());
        }
        _ => {}
    }
}

/// Finds or creates the reasoning item for a content block index in the
/// accumulating streaming response.
fn ensure_reasoning_item(
    response: &mut ChatCompletionResponse,
    content_block_index: usize,
) -> &mut ReasoningItem {
    let reasoning = response.reasoning.get_or_insert_with(Vec::new);
    let reasoning_id = format!("bedrock_reasoning_{content_block_index}");
    // Two-step lookup (position, then get_mut) keeps the borrows disjoint.
    if let Some(position) = reasoning.iter().position(|item| item.id == reasoning_id) {
        return reasoning
            .get_mut(position)
            .expect("position from iter().position must exist");
    }
    reasoning.push(ReasoningItem {
        id: reasoning_id,
        summary: Vec::new(),
        content: None,
        encrypted_content: None,
        status: None,
    });
    reasoning
        .last_mut()
        .expect("pushed reasoning item must exist")
}

/// Converts a `serde_json::Value` into an SDK `Document` by serializing to
/// bytes and re-parsing them with the SDK's JSON tokenizer.
fn json_value_to_document(value: &serde_json::Value) -> Result<Document, LanguageModelError> {
    let bytes = serde_json::to_vec(value).map_err(LanguageModelError::permanent)?;
    parse_document_json_bytes(&bytes).map_err(LanguageModelError::permanent)
}

/// Parses raw JSON bytes into exactly one SDK `Document`, rejecting any
/// trailing tokens after the first value.
fn parse_document_json_bytes(input: &[u8]) -> anyhow::Result<Document> {
    let mut tokens = json_token_iter(input).peekable();
    let document = expect_document(&mut tokens)?;
    if tokens.next().transpose()?.is_some() {
        anyhow::bail!("JSON input must contain exactly one value");
    }
    Ok(document)
}

// Lookup tables mapping media types / file extensions to the format strings
// used for the Bedrock Converse content blocks. Matching is case-insensitive
// (see `mapped_format`).
const IMAGE_MEDIA_TYPE_FORMATS: &[(&str, &str)] = &[
    ("image/gif", "gif"),
    ("image/jpeg", "jpeg"),
    ("image/jpg", "jpeg"),
    ("image/png", "png"),
    ("image/webp", "webp"),
];

const IMAGE_EXTENSION_FORMATS: &[(&str, &str)] = &[
    ("gif", "gif"),
    ("jpeg", "jpeg"),
    ("jpg", "jpeg"),
    ("png", "png"),
    ("webp", "webp"),
];

const DOCUMENT_MEDIA_TYPE_FORMATS: &[(&str, &str)] = &[
    ("text/csv", "csv"),
    ("application/msword", "doc"),
    (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "docx",
    ),
    ("text/html", "html"),
    ("text/markdown", "md"),
    ("text/x-markdown", "md"),
    ("application/pdf", "pdf"),
    ("text/plain", "txt"),
    ("application/vnd.ms-excel", "xls"),
    (
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "xlsx",
    ),
];

const DOCUMENT_EXTENSION_FORMATS: &[(&str, &str)] = &[
    ("csv", "csv"),
    ("doc", "doc"),
    ("docx", "docx"),
    ("html", "html"),
    ("htm", "html"),
    ("md", "md"),
    ("markdown", "md"),
    ("pdf", "pdf"),
    ("txt", "txt"),
    ("xls", "xls"),
    ("xlsx", "xlsx"),
];

const AUDIO_MEDIA_TYPE_FORMATS: &[(&str, &str)] = &[
    ("audio/aac", "aac"),
    ("audio/flac", "flac"),
    ("audio/m4a", "m4a"),
    ("audio/mka", "mka"),
    ("audio/x-matroska", "mkv"),
    ("audio/mpeg", "mp3"),
    ("audio/mp3", "mp3"),
    ("audio/mp4", "mp4"),
    ("audio/ogg", "ogg"),
    ("audio/opus", "opus"),
    ("audio/wav", "wav"),
    ("audio/x-wav", "wav"),
    ("audio/wave", "wav"),
    ("audio/webm", "webm"),
    ("audio/x-aac", "x-aac"),
];

const AUDIO_EXTENSION_FORMATS: &[(&str, &str)] = &[
    ("aac", "aac"),
    ("flac", "flac"),
    ("m4a", "m4a"),
    ("mka", "mka"),
    ("mkv", "mkv"),
    ("mp3", "mp3"),
    ("mp4", "mp4"),
    ("mpeg", "mpeg"),
    ("mpga", "mpga"),
    ("ogg", "ogg"),
    ("opus",
    "opus"),
    ("pcm", "pcm"),
    ("wav", "wav"),
    ("webm", "webm"),
    ("x-aac", "x-aac"),
];

const VIDEO_MEDIA_TYPE_FORMATS: &[(&str, &str)] = &[
    ("video/x-flv", "flv"),
    ("video/x-matroska", "mkv"),
    ("video/quicktime", "mov"),
    ("video/mp4", "mp4"),
    ("video/mpeg", "mpeg"),
    ("video/3gpp", "three_gp"),
    ("video/webm", "webm"),
    ("video/x-ms-wmv", "wmv"),
];

const VIDEO_EXTENSION_FORMATS: &[(&str, &str)] = &[
    ("flv", "flv"),
    ("mkv", "mkv"),
    ("mov", "mov"),
    ("mp4", "mp4"),
    ("mpeg", "mpeg"),
    ("mpg", "mpg"),
    ("3gp", "three_gp"),
    ("webm", "webm"),
    ("wmv", "wmv"),
];

#[cfg(test)]
mod tests {
    use aws_sdk_bedrockruntime::Client;
    use aws_sdk_bedrockruntime::{
        operation::converse::ConverseOutput,
        types::{
            ContentBlockDeltaEvent, ContentBlockStart, ContentBlockStartEvent,
            ConverseOutput as ConverseResult, Message, MessageStopEvent, ReasoningContentBlock,
            ReasoningContentBlockDelta, ReasoningTextBlock, StopReason, TokenUsage,
            ToolUseBlockDelta, ToolUseBlockStart,
        },
    };
    use futures_util::StreamExt as _;
    use schemars::{JsonSchema, schema_for};
    use serde_json::{Value, json};
    use swiftide_core::chat_completion::{
        ChatMessage, ChatMessageContentPart, ChatMessageContentSource, ReasoningItem, ToolSpec,
    };
    use wiremock::{
        Mock, MockServer, Request, Respond, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    #[cfg(feature = "langfuse")]
    use crate::aws_bedrock_v2::test_utils::run_with_langfuse_event_capture;
    use crate::aws_bedrock_v2::{
        AwsBedrock, MockBedrockConverse, ReasoningEffort,
        test_utils::{TEST_MODEL_ID, bedrock_client_for_mock_server, converse_stream_event},
    };

    // Simple argument type used to exercise tool schema generation.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)]
    struct WeatherArgs {
        location: String,
    }

    // Nested argument types for testing strict/nested schema handling.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)]
    #[serde(deny_unknown_fields)]
    struct NestedCommentArgs {
        request: NestedCommentRequest,
    }

    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)]
    #[serde(deny_unknown_fields)]
    struct NestedCommentRequest {
        #[serde(default,
        skip_serializing_if = "Option::is_none")]
        body: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        text: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        page_id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        block_id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        discussion_id: Option<String>,
    }

    // Shared fixture: a Converse response containing text, one tool call, and usage.
    fn response_with_text_and_tool_call() -> ConverseOutput {
        let mut args = HashMap::new();
        args.insert(
            "location".to_string(),
            Document::String("Amsterdam".to_string()),
        );
        ConverseOutput::builder()
            .output(ConverseResult::Message(
                Message::builder()
                    .role(ConversationRole::Assistant)
                    .content(ContentBlock::Text("Working on it".to_string()))
                    .content(ContentBlock::ToolUse(
                        ToolUseBlock::builder()
                            .tool_use_id("call_1")
                            .name("get_weather")
                            .input(Document::Object(args))
                            .build()
                            .unwrap(),
                    ))
                    .build()
                    .unwrap(),
            ))
            .usage(
                TokenUsage::builder()
                    .input_tokens(10)
                    .output_tokens(8)
                    .total_tokens(18)
                    .build()
                    .unwrap(),
            )
            .stop_reason(StopReason::ToolUse)
            .build()
            .unwrap()
    }

    // Verifies `complete` maps text, tool calls, and usage onto the response.
    #[test_log::test(tokio::test)]
    async fn test_complete_maps_text_and_tool_calls() {
        let mut bedrock_mock = MockBedrockConverse::new();
        bedrock_mock
            .expect_converse()
            .once()
            .withf(
                |model_id,
                 messages,
                 system,
                 inference_config,
                 tool_config,
                 output_config,
                 _additional_model_request_fields,
                 _additional_model_response_field_paths| {
                    model_id == "anthropic.claude-3-5-sonnet-20241022-v2:0"
                        && messages.len() == 1
                        && system.is_none()
                        && inference_config.is_none()
                        && tool_config.is_none()
                        && output_config.is_none()
                },
            )
            .returning(|_, _, _, _, _, _, _, _| Ok(response_with_text_and_tool_call()));

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Check weather".into())])
            .build()
            .unwrap();

        let response = bedrock.complete(&request).await.unwrap();
        assert_eq!(response.message.as_deref(), Some("Working on it"));
        let tool_call = response
            .tool_calls
            .as_ref()
            .and_then(|calls| calls.first())
            .expect("tool call");
        assert_eq!(tool_call.id(), "call_1");
        assert_eq!(tool_call.name(), "get_weather");
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(tool_call.args().unwrap()).unwrap(),
            serde_json::json!({"location":"Amsterdam"})
        );
        assert_eq!(response.usage.unwrap().total_tokens, 18);
    }

    // Verifies a failed `converse` call emits Langfuse failure metadata
    // (model, input, and status message).
    #[cfg(feature = "langfuse")]
    #[test]
    fn test_complete_tracks_langfuse_failure_metadata_on_converse_error() {
        let mut bedrock_mock = MockBedrockConverse::new();
        bedrock_mock
            .expect_converse()
            .once()
            .returning(|_, _, _, _, _, _, _, _| {
                Err(LanguageModelError::permanent("bedrock request failed"))
            });

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Trace this failure".into())])
            .build()
            .unwrap();

        let (result, events) =
            run_with_langfuse_event_capture(|| async { bedrock.complete(&request).await });

        let error = result.expect_err("request should fail");
        assert!(error.to_string().contains("bedrock request failed"));

        let failure_event = events
            .iter()
            .find(|event| event.contains_key("langfuse.status_message"))
            .expect("langfuse failure event");
        assert_eq!(
            failure_event
                .get("langfuse.model")
                .map(std::string::String::as_str),
            Some("anthropic.claude-3-5-sonnet-20241022-v2:0")
        );
        assert!(
            failure_event
                .get("langfuse.input")
                .is_some_and(|input| input.contains("Trace this failure"))
        );
        assert!(
            failure_event
                .get("langfuse.status_message")
                .is_some_and(|message| message.contains("bedrock request failed"))
        );
    }

    // Verifies additional model request fields and response field paths are
    // forwarded to the Converse call (start; continues below).
    #[test_log::test(tokio::test)]
    async fn test_complete_passes_additional_model_fields() {
        let mut bedrock_mock = MockBedrockConverse::new();
        let mut thinking = HashMap::new();
        thinking.insert("type".to_string(), Document::String("enabled".to_string()));
        thinking.insert("budget_tokens".to_string(), Document::from(512_u64));
        let mut request_fields = HashMap::new();
        request_fields.insert("thinking".to_string(), Document::Object(thinking));
        let request_fields = Document::Object(request_fields);

        bedrock_mock
            .expect_converse()
            .once()
            .withf(
                |model_id,
                 _,
                 _,
                 _,
                 _,
                 _,
                 additional_model_request_fields,
                 additional_model_response_field_paths| {
                    // Both the request fields and the response field paths must
                    // reach the Converse call unchanged.
                    model_id == "anthropic.claude-3-5-sonnet-20241022-v2:0"
                        && additional_model_request_fields
                            .as_ref()
                            .is_some_and(|fields| {
                                fields
                                    .as_object()
                                    .and_then(|map| map.get("thinking"))
                                    .and_then(Document::as_object)
                                    .and_then(|thinking| thinking.get("type"))
                                    .and_then(Document::as_string)
                                    == Some("enabled")
                            })
                        && additional_model_response_field_paths
                            .as_ref()
                            .is_some_and(|paths| paths == &vec!["/thinking".to_string()])
                },
            )
            .returning(|_, _, _, _, _, _, _, _| Ok(response_with_text_and_tool_call()));

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .default_options(Options {
                additional_model_request_fields: Some(request_fields),
                additional_model_response_field_paths: Some(vec!["/thinking".to_string()]),
                ..Default::default()
            })
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Hello".into())])
            .build()
            .unwrap();

        let _ = bedrock.complete(&request).await.unwrap();
    }

    // Verifies the reasoning-effort option is translated into the model
    // request fields for Claude Opus 4.5 (effort, thinking, beta flag).
    #[test_log::test(tokio::test)]
    async fn test_complete_passes_reasoning_effort_for_claude_opus_4_5() {
        let mut bedrock_mock = MockBedrockConverse::new();
        let mut thinking = HashMap::new();
        thinking.insert("type".to_string(), Document::String("enabled".to_string()));
        thinking.insert("budget_tokens".to_string(), Document::from(512_u64));
        let mut request_fields = HashMap::new();
        request_fields.insert("thinking".to_string(), Document::Object(thinking));
        let request_fields = Document::Object(request_fields);

        bedrock_mock
            .expect_converse()
            .once()
            .withf(
                |model_id,
                 _,
                 _,
                 _,
                 _,
                 _,
                 additional_model_request_fields,
                 _additional_model_response_field_paths| {
                    model_id == "anthropic.claude-opus-4-5-20251101-v1:0"
                        && additional_model_request_fields
                            .as_ref()
                            .is_some_and(|fields| {
                                let Some(fields) = fields.as_object() else {
                                    return false;
                                };
                                let effort_matches = fields
                                    .get("output_config")
                                    .and_then(Document::as_object)
                                    .and_then(|output_config| output_config.get("effort"))
                                    .and_then(Document::as_string)
                                    == Some("medium");
                                let thinking_matches = fields
                                    .get("thinking")
                                    .and_then(Document::as_object)
                                    .and_then(|thinking| thinking.get("type"))
                                    .and_then(Document::as_string)
                                    == Some("enabled");
                                let beta_matches = fields
                                    .get("anthropic_beta")
                                    .and_then(Document::as_array)
                                    .is_some_and(|betas| {
                                        betas.iter().any(|beta| {
                                            beta.as_string() == Some("effort-2025-11-24")
                                        })
                                    });
                                effort_matches && thinking_matches && beta_matches
                            })
                },
            )
            .returning(|_, _, _, _, _, _, _, _| Ok(response_with_text_and_tool_call()));

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .default_prompt_model("anthropic.claude-opus-4-5-20251101-v1:0")
            .default_options(Options {
                reasoning_effort: Some(ReasoningEffort::Medium),
                additional_model_request_fields: Some(request_fields),
                ..Default::default()
            })
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Hello".into())])
            .build()
            .unwrap();

        let _ = bedrock.complete(&request).await.unwrap();
    }

    // Verifies `tool_strict: Some(false)` propagates to the tool spec
    // (start; the match arm continues below).
    #[test_log::test(tokio::test)]
    #[allow(deprecated)]
    async fn test_complete_respects_tool_strict_option() {
        let mut bedrock_mock = MockBedrockConverse::new();
        bedrock_mock
            .expect_converse()
            .once()
            .withf(
                |model_id,
                 _,
                 _,
                 _,
                 tool_config,
                 output_config,
                 _additional_model_request_fields,
                 _additional_model_response_field_paths| {
                    model_id == "anthropic.claude-3-5-sonnet-20241022-v2:0"
                        && output_config.is_none()
                        && tool_config
                            .as_ref()
                            .and_then(|config| config.tools().first())
                            .is_some_and(|tool| match tool {
                                Tool::ToolSpec(spec) => spec.strict() == Some(false),
                                _ => false,
                            })
                },
            )
            .returning(|_, _, _, _, _, _, _, _| Ok(response_with_text_and_tool_call()));

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .default_options(Options {
                tool_strict: Some(false),
                ..Default::default()
            })
            .build()
            .unwrap();

        let tool_spec = ToolSpec::builder()
            .name("get_weather")
            .description("Get weather")
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Check weather".into())])
            .tools_spec([tool_spec])
            .build()
            .unwrap();

        let _ = bedrock.complete(&request).await.unwrap();
    }

    // Without a default model the stream yields a single permanent error and
    // never calls Bedrock.
    #[test_log::test(tokio::test)]
    async fn test_complete_stream_requires_model() {
        let mut bedrock_mock = MockBedrockConverse::new();
        bedrock_mock.expect_converse_stream().never();

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::new_user("Hello")])
            .build()
            .unwrap();

        let mut stream = bedrock.complete_stream(&request).await;
        let first = stream.next().await.expect("stream should yield one item");
        assert!(matches!(first, Err(LanguageModelError::PermanentError(_))));
        assert!(stream.next().await.is_none());
    }

    // A failing stream start should still emit Langfuse failure metadata.
    #[cfg(feature = "langfuse")]
    #[test]
    fn test_complete_stream_tracks_langfuse_failure_metadata_on_stream_error() {
        let mut bedrock_mock = MockBedrockConverse::new();
        bedrock_mock
            .expect_converse_stream()
            .once()
            .returning(|_, _, _, _, _, _, _| {
                Err(LanguageModelError::transient("bedrock stream failed"))
            });

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::new_user("Stream this failure")])
            .build()
            .unwrap();

        let (first_item, events) = run_with_langfuse_event_capture(|| async {
            let mut stream = bedrock.complete_stream(&request).await;
            stream.next().await.expect("stream should yield an error")
        });

        let error = first_item.expect_err("stream should fail");
        assert!(error.to_string().contains("bedrock stream failed"));

        let failure_event = events
            .iter()
            .find(|event| event.contains_key("langfuse.status_message"))
            .expect("langfuse failure event");
        assert!(
            failure_event
                .get("langfuse.input")
                .is_some_and(|input| input.contains("Stream this failure"))
        );
        assert!(
            failure_event
                .get("langfuse.status_message")
                .is_some_and(|message| message.contains("bedrock stream failed"))
        );
    }

    // System-only requests are rejected before any Bedrock call is made.
    #[test_log::test(tokio::test)]
    async fn test_complete_stream_rejects_system_only_messages() {
        let mut bedrock_mock = MockBedrockConverse::new();
        bedrock_mock.expect_converse_stream().never();

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::new_system("You are a helper")])
            .build()
            .unwrap();

        let mut stream = bedrock.complete_stream(&request).await;
        let first = stream.next().await.expect("stream should yield one item");
        assert!(matches!(first, Err(LanguageModelError::PermanentError(_))));
        assert!(stream.next().await.is_none());
    }

    // A transient error from the stream start is surfaced as a single
    // transient stream item.
    #[test_log::test(tokio::test)]
    async fn test_complete_stream_returns_upstream_stream_error() {
        let mut bedrock_mock = MockBedrockConverse::new();
        bedrock_mock
            .expect_converse_stream()
            .once()
            .withf(
                |model_id,
                 messages,
                 system,
                 inference_config,
                 tool_config,
                 _additional_model_request_fields,
                 _additional_model_response_field_paths| {
                    model_id == "anthropic.claude-3-5-sonnet-20241022-v2:0"
                        && messages.len() == 1
                        && matches!(messages[0].role(), ConversationRole::User)
                        && matches!(messages[0].content().first(), Some(ContentBlock::Text(text)) if text == "Hello")
                        && system.is_none()
                        && inference_config.is_none()
                        && tool_config.is_none()
                },
            )
            .returning(|_, _, _, _, _, _, _| {
                Err(LanguageModelError::transient(anyhow::anyhow!(
                    "stream init failed"
                )))
            });

        let bedrock = AwsBedrock::builder()
            .test_client(bedrock_mock)
            .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::new_user("Hello")])
            .build()
            .unwrap();

        let mut stream = bedrock.complete_stream(&request).await;
        let first = stream.next().await.expect("stream should yield one item");
        assert!(matches!(first, Err(LanguageModelError::TransientError(_))));
        assert!(stream.next().await.is_none());
    }

    // End-to-end green path against a wiremock HTTP server: validates the
    // serialized Converse request and the parsed response.
    #[test_log::test(tokio::test)]
    async fn test_complete_green_path_with_wiremock() {
        struct ValidateConverseRequest;

        impl Respond for ValidateConverseRequest {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let payload: Value = serde_json::from_slice(&request.body).expect("request json");
                assert_eq!(payload["messages"][0]["role"], "user");
                assert_eq!(payload["messages"][0]["content"][0]["text"], "Hello");

                ResponseTemplate::new(200).set_body_json(json!({
                    "output": {
                        "message": {
                            "role": "assistant",
                            "content": [
                                {"text": "Hello from bedrock"}
                            ]
                        }
                    },
                    "stopReason": "end_turn",
                    "usage": {
                        "inputTokens": 2,
                        "outputTokens": 5,
                        "totalTokens": 7
                    }
                }))
            }
        }

        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path(format!("/model/{TEST_MODEL_ID}/converse")))
            .respond_with(ValidateConverseRequest)
            .mount(&mock_server)
            .await;

        let client: Client = bedrock_client_for_mock_server(&mock_server.uri());
        let bedrock = AwsBedrock::builder()
            .client(client)
            .default_prompt_model(TEST_MODEL_ID)
            .build()
            .unwrap();

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::new_user("Hello")])
            .build()
            .unwrap();

        let response = bedrock.complete(&request).await.unwrap();
        assert_eq!(response.message(), Some("Hello from bedrock"));
        assert_eq!(
            response.usage.as_ref().map(|usage| usage.total_tokens),
            Some(7)
        );
    }

    #[test_log::test(tokio::test)]
    async fn
test_complete_stream_green_path_with_wiremock() { struct ValidateConverseStreamRequest { stream_body: Vec<u8>, } impl Respond for ValidateConverseStreamRequest { fn respond(&self, request: &Request) -> ResponseTemplate { let payload: Value = serde_json::from_slice(&request.body).expect("request json"); assert_eq!(payload["messages"][0]["role"], "user"); assert_eq!(payload["messages"][0]["content"][0]["text"], "Hello"); ResponseTemplate::new(200).set_body_raw( self.stream_body.clone(), "application/vnd.amazon.eventstream", ) } } let mock_server = MockServer::start().await; let stream_body = [ converse_stream_event( "contentBlockDelta", &json!({ "contentBlockIndex": 0, "delta": {"text": "Hello stream"} }), ), converse_stream_event( "metadata", &json!({ "usage": { "inputTokens": 4, "outputTokens": 5, "totalTokens": 9 } }), ), converse_stream_event( "messageStop", &json!({ "stopReason": "end_turn" }), ), ] .concat(); Mock::given(method("POST")) .and(path(format!("/model/{TEST_MODEL_ID}/converse-stream"))) .respond_with(ValidateConverseStreamRequest { stream_body }) .mount(&mock_server) .await; let client: Client = bedrock_client_for_mock_server(&mock_server.uri()); let bedrock = AwsBedrock::builder() .client(client) .default_prompt_model(TEST_MODEL_ID) .build() .unwrap(); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::new_user("Hello")]) .build() .unwrap(); let responses = bedrock .complete_stream(&request) .await .collect::<Vec<_>>() .await; let last = responses .last() .expect("stream should yield") .as_ref() .expect("last response ok"); assert_eq!(last.message(), Some("Hello stream")); assert_eq!(last.usage.as_ref().map(|usage| usage.total_tokens), Some(9)); } #[test] fn test_tool_config_from_specs_builds_schema() { let tool_spec = ToolSpec::builder() .name("get_weather") .description("Get weather by location") .parameters_schema(schema_for!(WeatherArgs)) .build() .unwrap(); let request = ChatCompletionRequest::builder() 
.messages(vec![ChatMessage::User("hi".into())]) .tool_specs([tool_spec]) .build() .unwrap(); let tool_config = tool_config_from_specs(request.tools_spec().iter(), true) .unwrap() .expect("tool config"); assert_eq!(tool_config.tools().len(), 1); let Tool::ToolSpec(spec) = &tool_config.tools()[0] else { panic!("expected tool spec"); }; assert_eq!(spec.name(), "get_weather"); assert_eq!(spec.description(), Some("Get weather by location")); assert_eq!(spec.strict(), Some(true)); assert!(matches!( spec.input_schema(), Some(ToolInputSchema::Json(Document::Object(schema))) if schema.get("type") == Some(&Document::String("object".to_string())) && schema.get("additionalProperties") == Some(&Document::Bool(false)) )); } #[test] fn test_tool_config_from_specs_can_disable_strict() { let tool_spec = ToolSpec::builder() .name("get_weather") .description("Get weather") .build() .unwrap(); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .tool_specs([tool_spec]) .build() .unwrap(); let tool_config = tool_config_from_specs(request.tools_spec().iter(), false) .unwrap() .expect("tool config"); let Tool::ToolSpec(spec) = &tool_config.tools()[0] else { panic!("expected tool spec"); }; assert_eq!(spec.strict(), Some(false)); assert!(matches!( spec.input_schema(), Some(ToolInputSchema::Json(Document::Object(schema))) if schema.get("type") == Some(&Document::String("object".to_string())) && schema.get("additionalProperties") == Some(&Document::Bool(false)) )); } #[test] fn test_tool_config_from_specs_does_not_apply_openai_required_workaround() { let tool_spec = ToolSpec::builder() .name("create_comment") .description("Create a comment") .parameters_schema(schema_for!(NestedCommentArgs)) .build() .unwrap(); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .tool_specs([tool_spec]) .build() .unwrap(); let tool_config = tool_config_from_specs(request.tools_spec().iter(), true) .unwrap() .expect("tool 
config"); let Tool::ToolSpec(spec) = &tool_config.tools()[0] else { panic!("expected tool spec"); }; let Some(ToolInputSchema::Json(Document::Object(schema))) = spec.input_schema() else { panic!("expected JSON object schema"); }; assert_eq!( schema.get("type"), Some(&Document::String("object".to_string())) ); assert_eq!( schema.get("additionalProperties"), Some(&Document::Bool(false)) ); assert_eq!( schema.get("required"), Some(&Document::Array(vec![Document::String( "request".to_string() )])) ); let Some(Document::Object(properties)) = schema.get("properties") else { panic!("expected properties map"); }; let Some(Document::String(nested_ref)) = properties .get("request") .and_then(Document::as_object) .and_then(|request| request.get("$ref")) else { panic!("expected nested request $ref"); }; let nested_name = nested_ref .rsplit('/') .next() .expect("nested request ref name"); let Some(Document::Object(defs)) = schema.get("$defs") else { panic!("expected defs map"); }; let Some(Document::Object(nested_schema)) = defs.get(nested_name) else { panic!("expected nested request schema"); }; assert!(!nested_schema.contains_key("required")); } #[test] fn test_tool_config_from_specs_orders_tools_deterministically() { let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .tool_specs([ ToolSpec::builder() .name("z_tool") .description("later") .build() .unwrap(), ToolSpec::builder() .name("a_tool") .description("earlier") .build() .unwrap(), ]) .build() .unwrap(); let tool_config = tool_config_from_specs(request.tools_spec().iter(), true) .unwrap() .expect("tool config"); let tool_names = tool_config .tools() .iter() .map(|tool| match tool { Tool::ToolSpec(spec) => spec.name(), _ => panic!("expected tool spec"), }) .collect::<Vec<_>>(); assert_eq!(tool_names, vec!["a_tool", "z_tool"]); } #[test] fn test_response_to_chat_completion_maps_reasoning_content() { let response = ConverseOutput::builder() .output(ConverseResult::Message( 
            Message::builder()
                .role(ConversationRole::Assistant)
                // One reasoning block plus one plain text block in the same
                // assistant message, to check both are mapped independently.
                .content(ContentBlock::ReasoningContent(
                    ReasoningContentBlock::ReasoningText(
                        ReasoningTextBlock::builder()
                            .text("I should call a weather tool")
                            .signature("sig_123")
                            .build()
                            .unwrap(),
                    ),
                ))
                .content(ContentBlock::Text("Working on it".to_string()))
                .build()
                .unwrap(),
        ))
        .stop_reason(StopReason::EndTurn)
        .build()
        .unwrap();

        let completion = response_to_chat_completion(&response).unwrap();

        assert_eq!(completion.message.as_deref(), Some("Working on it"));
        let reasoning = completion.reasoning.expect("reasoning items");
        assert_eq!(reasoning.len(), 1);
        // Synthetic id derived from the content block index (block 0).
        assert_eq!(reasoning[0].id, "bedrock_reasoning_0");
        assert_eq!(
            reasoning[0].content.as_ref().and_then(|c| c.first()),
            Some(&"I should call a weather tool".to_string())
        );
        // The Bedrock signature is carried over as encrypted_content.
        assert_eq!(reasoning[0].encrypted_content.as_deref(), Some("sig_123"));
    }

    /// Reasoning items from a previous turn must be replayed to Bedrock as
    /// assistant reasoning-content blocks (text + signature round-trip).
    #[test]
    fn test_build_converse_input_replays_reasoning_items() {
        let request = ChatCompletionRequest::builder()
            .messages(vec![
                ChatMessage::Reasoning(ReasoningItem {
                    id: "r1".to_string(),
                    summary: Vec::new(),
                    content: Some(vec!["I should call a weather tool".to_string()]),
                    encrypted_content: Some("sig_123".to_string()),
                    status: None,
                }),
                ChatMessage::new_user("Use tool"),
            ])
            .build()
            .unwrap();

        let (messages, _system, _inference, _tool_config) =
            build_converse_input(&request, &Options::default()).unwrap();

        assert_eq!(messages.len(), 2);
        assert!(matches!(messages[0].role(), ConversationRole::Assistant));
        let reasoning = messages[0]
            .content()
            .first()
            .and_then(|content| content.as_reasoning_content().ok())
            .and_then(|content| content.as_reasoning_text().ok())
            .expect("reasoning content");
        assert_eq!(reasoning.text(), "I should call a weather tool");
        assert_eq!(reasoning.signature(), Some("sig_123"));
    }

    /// Two consecutive tool outputs must be merged into a single user message
    /// carrying two tool-result blocks (Bedrock rejects back-to-back user
    /// messages with one result each).
    #[test]
    fn test_build_converse_input_groups_adjacent_tool_outputs() {
        let first_tool = ToolCall::builder()
            .id("tool_1")
            .name("shell_command")
            .args("{\"cmd\":\"pwd\"}")
            .build()
            .unwrap();
        let second_tool =
            ToolCall::builder()
                .id("tool_2")
                .name("git")
                .args("{\"command\":\"status\"}")
                .build()
                .unwrap();
        let request = ChatCompletionRequest::builder()
            .messages(vec![
                ChatMessage::Assistant(None, Some(vec![first_tool.clone(), second_tool.clone()])),
                ChatMessage::new_tool_output(
                    first_tool,
                    ToolOutput::Text("pwd output".to_string()),
                ),
                ChatMessage::new_tool_output(
                    second_tool,
                    ToolOutput::Text("git output".to_string()),
                ),
            ])
            .build()
            .unwrap();

        let (messages, _system, _inference, _tool_config) =
            build_converse_input(&request, &Options::default()).unwrap();

        // Assistant tool-call message + ONE grouped user message, not three.
        assert_eq!(messages.len(), 2);
        assert!(matches!(messages[0].role(), ConversationRole::Assistant));
        assert!(matches!(messages[1].role(), ConversationRole::User));
        assert_eq!(messages[1].content().len(), 2);
        let first_result = messages[1]
            .content()
            .first()
            .and_then(|block| block.as_tool_result().ok())
            .expect("first tool result");
        let second_result = messages[1]
            .content()
            .get(1)
            .and_then(|block| block.as_tool_result().ok())
            .expect("second tool result");
        // Order of the grouped results must follow the original message order.
        assert_eq!(first_result.tool_use_id(), "tool_1");
        assert_eq!(second_result.tool_use_id(), "tool_2");
    }

    /// A data-URL image part must be converted to a Bedrock image block with
    /// inline bytes and the format inferred from the MIME type.
    #[test]
    fn test_build_converse_input_maps_image_part() {
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::new_user_with_parts(vec![
                ChatMessageContentPart::text("Describe this image"),
                ChatMessageContentPart::image("data:image/png;base64,AA=="),
            ])])
            .build()
            .unwrap();

        let (messages, _system, _inference, _tool_config) =
            build_converse_input(&request, &Options::default()).unwrap();

        assert_eq!(messages.len(), 1);
        assert!(matches!(messages[0].role(), ConversationRole::User));
        assert_eq!(messages[0].content().len(), 2);
        let image = messages[0]
            .content()
            .get(1)
            .and_then(|content| content.as_image().ok())
            .expect("image block");
        assert!(matches!(image.format(), ImageFormat::Png));
        assert!(
            image
                .source()
                .is_some_and(aws_sdk_bedrockruntime::types::ImageSource::is_bytes)
        );
    }

    #[test]
    fn test_build_converse_input_maps_audio_part() {
let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::new_user_with_parts(vec![ ChatMessageContentPart::text("Transcribe this"), ChatMessageContentPart::audio(ChatMessageContentSource::bytes( vec![1_u8, 2_u8, 3_u8], Some("audio/mpeg".to_string()), )), ])]) .build() .unwrap(); let (messages, _system, _inference, _tool_config) = build_converse_input(&request, &Options::default()).unwrap(); let audio = messages[0] .content() .get(1) .and_then(|content| content.as_audio().ok()) .expect("audio block"); assert!(matches!(audio.format(), AudioFormat::Mp3)); assert!( audio .source() .is_some_and(aws_sdk_bedrockruntime::types::AudioSource::is_bytes) ); } #[test] fn test_build_converse_input_maps_video_part() { let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::new_user_with_parts(vec![ ChatMessageContentPart::text("Describe this clip"), ChatMessageContentPart::video("s3://bucket/video.mp4"), ])]) .build() .unwrap(); let (messages, _system, _inference, _tool_config) = build_converse_input(&request, &Options::default()).unwrap(); let video = messages[0] .content() .get(1) .and_then(|content| content.as_video().ok()) .expect("video block"); assert!(matches!(video.format(), VideoFormat::Mp4)); assert!( video .source() .is_some_and(aws_sdk_bedrockruntime::types::VideoSource::is_s3_location) ); } #[test] fn test_build_converse_input_rejects_audio_http_url() { let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::new_user_with_parts(vec![ ChatMessageContentPart::text("Transcribe this"), ChatMessageContentPart::audio("https://example.com/audio.mp3"), ])]) .build() .unwrap(); let error = build_converse_input(&request, &Options::default()).unwrap_err(); assert!(format!("{error}").contains("audio source URL must be data: or s3://")); } #[test] fn test_build_converse_input_rejects_document_without_text() { let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::new_user_with_parts(vec![ 
ChatMessageContentPart::document(ChatMessageContentSource::bytes( vec![1_u8, 2_u8], Some("text/plain".to_string()), )), ])]) .build() .unwrap(); let error = build_converse_input(&request, &Options::default()).unwrap_err(); assert!(format!("{error}").contains("require at least one text part")); } #[test] #[allow(clippy::too_many_lines)] fn test_apply_stream_event_accumulates_deltas() { let mut response = ChatCompletionResponse::default(); let mut stop_reason = None; apply_stream_event( &ConverseStreamOutput::ContentBlockStart( ContentBlockStartEvent::builder() .content_block_index(0) .start(ContentBlockStart::ToolUse( ToolUseBlockStart::builder() .tool_use_id("call_1") .name("get_weather") .build() .unwrap(), )) .build() .unwrap(), ), &mut response, &mut stop_reason, ); apply_stream_event( &ConverseStreamOutput::ContentBlockDelta( ContentBlockDeltaEvent::builder() .content_block_index(0) .delta(ContentBlockDelta::ToolUse( ToolUseBlockDelta::builder() .input("{\"location\":\"Amsterdam\"}") .build() .unwrap(), )) .build() .unwrap(), ), &mut response, &mut stop_reason, ); apply_stream_event( &ConverseStreamOutput::ContentBlockDelta( ContentBlockDeltaEvent::builder() .content_block_index(1) .delta(ContentBlockDelta::Text("Tool call created".to_string())) .build() .unwrap(), ), &mut response, &mut stop_reason, ); apply_stream_event( &ConverseStreamOutput::ContentBlockDelta( ContentBlockDeltaEvent::builder() .content_block_index(2) .delta(ContentBlockDelta::ReasoningContent( ReasoningContentBlockDelta::Text("Thinking...".to_string()), )) .build() .unwrap(), ), &mut response, &mut stop_reason, ); apply_stream_event( &ConverseStreamOutput::ContentBlockDelta( ContentBlockDeltaEvent::builder() .content_block_index(2) .delta(ContentBlockDelta::ReasoningContent( ReasoningContentBlockDelta::Signature("sig_123".to_string()), )) .build() .unwrap(), ), &mut response, &mut stop_reason, ); apply_stream_event( &ConverseStreamOutput::Metadata( 
aws_sdk_bedrockruntime::types::ConverseStreamMetadataEvent::builder() .usage( TokenUsage::builder() .input_tokens(5) .output_tokens(3) .total_tokens(8) .build() .unwrap(), ) .build(), ), &mut response, &mut stop_reason, ); apply_stream_event( &ConverseStreamOutput::MessageStop( MessageStopEvent::builder() .stop_reason(StopReason::ToolUse) .build() .unwrap(), ), &mut response, &mut stop_reason, ); assert_eq!(response.message.as_deref(), Some("Tool call created")); let tool_call = response .tool_calls .as_ref() .and_then(|calls| calls.first()) .expect("tool call"); assert_eq!(tool_call.id(), "call_1"); assert_eq!(tool_call.name(), "get_weather"); assert_eq!(tool_call.args(), Some("{\"location\":\"Amsterdam\"}")); let reasoning = response.reasoning.expect("reasoning item"); assert_eq!(reasoning.len(), 1); assert_eq!(reasoning[0].id, "bedrock_reasoning_2"); assert_eq!( reasoning[0].content.as_ref().and_then(|c| c.first()), Some(&"Thinking...".to_string()) ); assert_eq!(reasoning[0].encrypted_content.as_deref(), Some("sig_123")); assert_eq!(response.usage.unwrap().total_tokens, 8); assert!(matches!(stop_reason, Some(StopReason::ToolUse))); } } ================================================ FILE: swiftide-integrations/src/aws_bedrock_v2/mod.rs ================================================ use std::{pin::Pin, sync::Arc}; use async_trait::async_trait; use aws_sdk_bedrockruntime::{ Client, error::SdkError, operation::{ converse::{ConverseError, ConverseOutput}, converse_stream::{ ConverseStreamError, ConverseStreamOutput as BedrockConverseStreamOutput, }, }, types::{ InferenceConfiguration, Message, OutputConfig, StopReason, SystemContentBlock, TokenUsage, ToolConfiguration, error::ConverseStreamOutputError, }, }; use aws_smithy_types::Document; use derive_builder::Builder; use serde::Serialize; use swiftide_core::chat_completion::{ InputTokenDetails, Usage, UsageDetails, errors::LanguageModelError, }; use tokio::runtime::Handle; #[cfg(test)] use mockall::automock; mod 
chat_completion; mod simple_prompt; mod structured_prompt; #[cfg(test)] mod test_utils; mod tool_schema; /// Converse-based integration with AWS Bedrock. /// /// This integration uses Bedrock's unified Converse APIs (`Converse` + `ConverseStream`). #[derive(Builder, Clone)] #[builder(setter(into, strip_option))] pub struct AwsBedrock { /// The Bedrock runtime client. #[builder(default = self.default_client(), setter(custom))] client: Arc<dyn BedrockConverse>, /// Default options for prompt requests. #[builder(default, setter(custom))] default_options: Options, #[cfg(feature = "metrics")] #[builder(default)] /// Optional metadata to attach to metrics emitted by this client. metric_metadata: Option<std::collections::HashMap<String, String>>, /// A callback function that is called when usage information is available. #[builder(default, setter(custom))] #[allow(clippy::type_complexity)] on_usage: Option< Arc< dyn for<'a> Fn( &'a Usage, ) -> Pin< Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>, > + Send + Sync, >, >, } impl std::fmt::Debug for AwsBedrock { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AwsBedrock") .field("client", &self.client) .field("default_options", &self.default_options) .finish() } } /// Anthropic Claude effort guidance for Bedrock model-specific request fields. /// /// Bedrock currently documents the following support: /// - Claude Opus 4.5: `low`, `medium`, `high` via the `effort-2025-11-24` beta header. /// - Claude Opus 4.6 adaptive thinking: `low`, `medium`, `high`, `max` with no beta header. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] #[serde(rename_all = "lowercase")] pub enum ReasoningEffort { Low, Medium, High, Max, } impl ReasoningEffort { fn as_str(self) -> &'static str { match self { Self::Low => "low", Self::Medium => "medium", Self::High => "high", Self::Max => "max", } } } #[derive(Debug, Clone, Builder, Default)] #[builder(setter(strip_option))] pub struct Options { /// Model ID or ARN used as `modelId` in Converse requests. #[builder(default, setter(into))] pub prompt_model: Option<String>, /// Maximum number of tokens in the generated response. #[builder(default)] pub max_tokens: Option<i32>, /// Sampling temperature. #[builder(default)] pub temperature: Option<f32>, /// Nucleus sampling parameter. #[builder(default)] pub top_p: Option<f32>, /// Stop sequences for response generation. #[builder(default, setter(into))] pub stop_sequences: Option<Vec<String>>, /// Whether tool calls should enforce strict schema validation. /// /// Defaults to `true` when not set. #[builder(default)] pub tool_strict: Option<bool>, /// Anthropic beta headers forwarded through Bedrock model-specific request fields. /// /// This is useful for Anthropic features on Bedrock that require `anthropic_beta`. #[builder(default, setter(into))] pub anthropic_beta: Option<Vec<String>>, /// Anthropic Claude reasoning/token spend guidance forwarded through Bedrock /// `additional_model_request_fields.output_config.effort`. /// /// For Claude Opus 4.5, Bedrock requires the `effort-2025-11-24` beta header. Swiftide adds /// that header automatically when the configured model ID clearly identifies Claude Opus 4.5. /// If you route through an inference profile or ARN, also set `anthropic_beta` explicitly. /// /// For Claude Opus 4.6 adaptive thinking, Bedrock documents `max` in addition to the other /// levels. Use `additional_model_request_fields` to set `thinking.type = "adaptive"` when /// needed. 
#[builder(default)] pub reasoning_effort: Option<ReasoningEffort>, /// Provider-specific model request parameters passed to Converse. /// /// This is the Bedrock equivalent of model-specific reasoning controls. #[builder(default)] pub additional_model_request_fields: Option<Document>, /// JSON Pointer paths for model-specific response fields. #[builder(default, setter(into))] pub additional_model_response_field_paths: Option<Vec<String>>, } impl Options { pub fn builder() -> OptionsBuilder { OptionsBuilder::default() } pub fn tool_strict_enabled(&self) -> bool { self.tool_strict.unwrap_or(true) } pub fn merge(&mut self, other: Options) { if let Some(prompt_model) = other.prompt_model { self.prompt_model = Some(prompt_model); } if let Some(max_tokens) = other.max_tokens { self.max_tokens = Some(max_tokens); } if let Some(temperature) = other.temperature { self.temperature = Some(temperature); } if let Some(top_p) = other.top_p { self.top_p = Some(top_p); } if let Some(stop_sequences) = other.stop_sequences { self.stop_sequences = Some(stop_sequences); } if let Some(tool_strict) = other.tool_strict { self.tool_strict = Some(tool_strict); } if let Some(anthropic_beta) = other.anthropic_beta { self.anthropic_beta = Some(anthropic_beta); } if let Some(reasoning_effort) = other.reasoning_effort { self.reasoning_effort = Some(reasoning_effort); } if let Some(additional_model_request_fields) = other.additional_model_request_fields { self.additional_model_request_fields = Some(additional_model_request_fields); } if let Some(additional_model_response_field_paths) = other.additional_model_response_field_paths { self.additional_model_response_field_paths = Some(additional_model_response_field_paths); } } } impl AwsBedrock { pub fn builder() -> AwsBedrockBuilder { AwsBedrockBuilder::default() } /// Retrieve a reference to the default options. pub fn options(&self) -> &Options { &self.default_options } /// Retrieve a mutable reference to the default options. 
pub fn options_mut(&mut self) -> &mut Options { &mut self.default_options } fn prompt_model(&self) -> Result<&str, LanguageModelError> { self.default_options .prompt_model .as_deref() .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into())) } async fn report_usage(&self, model: &str, usage: &Usage) -> Result<(), LanguageModelError> { #[cfg(not(feature = "metrics"))] let _ = model; if let Some(callback) = &self.on_usage { callback(usage).await?; } #[cfg(feature = "metrics")] { swiftide_core::metrics::emit_usage( model, usage.prompt_tokens.into(), usage.completion_tokens.into(), usage.total_tokens.into(), self.metric_metadata.as_ref(), ); } Ok(()) } #[allow(unused_variables)] async fn track_completion<R, S>( &self, model: &str, usage: Option<&Usage>, request: Option<&R>, response: Option<&S>, ) -> Result<(), LanguageModelError> where R: Serialize + ?Sized, S: Serialize + ?Sized, { if let Some(usage) = usage { self.report_usage(model, usage).await?; } #[cfg(feature = "langfuse")] tracing::debug!( langfuse.model = model, langfuse.input = request.and_then(langfuse_json_redacted).unwrap_or_default(), langfuse.output = response.and_then(langfuse_json).unwrap_or_default(), langfuse.usage = usage.and_then(langfuse_json).unwrap_or_default(), ); Ok(()) } #[allow(unused_variables)] fn track_failure<R, S>( model: &str, request: Option<&R>, response: Option<&S>, error: &LanguageModelError, ) where R: Serialize + ?Sized, S: Serialize + ?Sized, { #[cfg(feature = "langfuse")] tracing::debug!( langfuse.model = model, langfuse.input = request.and_then(langfuse_json_redacted).unwrap_or_default(), langfuse.output = response.and_then(langfuse_json).unwrap_or_default(), langfuse.status_message = error.to_string(), ); } } impl AwsBedrockBuilder { #[allow(clippy::unused_self)] fn default_config(&self) -> aws_config::SdkConfig { tokio::task::block_in_place(|| Handle::current().block_on(aws_config::load_from_env())) } fn default_client(&self) -> Arc<Client> { 
Arc::new(Client::new(&self.default_config())) } /// Sets the Bedrock runtime client. pub fn client(&mut self, client: Client) -> &mut Self { self.client = Some(Arc::new(client)); self } /// Sets the default prompt model for Converse requests. pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self { if let Some(options) = self.default_options.as_mut() { options.prompt_model = Some(model.into()); } else { self.default_options = Some(Options { prompt_model: Some(model.into()), ..Default::default() }); } self } /// Sets default options for requests. /// /// Merges with existing options if already set. pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self { let options = options.into(); if let Some(existing_options) = self.default_options.as_mut() { existing_options.merge(options); } else { self.default_options = Some(options); } self } /// Adds a callback function that will be called when usage information is available. pub fn on_usage<F>(&mut self, func: F) -> &mut Self where F: Fn(&Usage) -> anyhow::Result<()> + Send + Sync + 'static, { let func = Arc::new(func); self.on_usage = Some(Some(Arc::new(move |usage: &Usage| { let func = func.clone(); Box::pin(async move { func(usage) }) }))); self } /// Adds an asynchronous callback function that will be called when usage information is /// available. 
    /// Registers an asynchronous callback that is invoked with the token
    /// [`Usage`] of every tracked completion.
    ///
    /// The user closure is wrapped so its borrowed-`Usage` future is stored
    /// as a pinned, boxed future behind an `Arc`; the nested `Some(Some(..))`
    /// follows the builder's optional-field representation.
    pub fn on_usage_async<F>(&mut self, func: F) -> &mut Self
    where
        F: for<'a> Fn(
                &'a Usage,
            ) -> Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>>
            + Send
            + Sync
            + 'static,
    {
        let func = Arc::new(func);
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage).await })
        })));
        self
    }

    /// Test-only hook replacing the real AWS client with a mock implementing
    /// the private [`BedrockConverse`] trait.
    #[cfg(test)]
    #[allow(private_bounds)]
    pub fn test_client(&mut self, client: impl BedrockConverse + 'static) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }
}

/// Seam over the AWS Bedrock runtime client so the Converse APIs can be
/// mocked in unit tests (`automock` generates `MockBedrockConverse` under
/// `cfg(test)`).
#[cfg_attr(test, automock)]
#[async_trait]
#[allow(clippy::too_many_arguments)]
trait BedrockConverse: std::fmt::Debug + Send + Sync {
    /// One-shot `Converse` request, mirroring the SDK fluent-builder inputs.
    async fn converse(
        &self,
        model_id: &str,
        messages: Vec<Message>,
        system: Option<Vec<SystemContentBlock>>,
        inference_config: Option<InferenceConfiguration>,
        tool_config: Option<ToolConfiguration>,
        output_config: Option<OutputConfig>,
        additional_model_request_fields: Option<Document>,
        additional_model_response_field_paths: Option<Vec<String>>,
    ) -> Result<ConverseOutput, LanguageModelError>;

    /// Streaming `ConverseStream` request. Note the streaming path takes no
    /// `output_config` parameter.
    async fn converse_stream(
        &self,
        model_id: &str,
        messages: Vec<Message>,
        system: Option<Vec<SystemContentBlock>>,
        inference_config: Option<InferenceConfiguration>,
        tool_config: Option<ToolConfiguration>,
        additional_model_request_fields: Option<Document>,
        additional_model_response_field_paths: Option<Vec<String>>,
    ) -> Result<BedrockConverseStreamOutput, LanguageModelError>;
}

#[async_trait]
#[allow(clippy::too_many_arguments)]
impl BedrockConverse for Client {
    async fn converse(
        &self,
        model_id: &str,
        messages: Vec<Message>,
        system: Option<Vec<SystemContentBlock>>,
        inference_config: Option<InferenceConfiguration>,
        tool_config: Option<ToolConfiguration>,
        output_config: Option<OutputConfig>,
        additional_model_request_fields: Option<Document>,
        additional_model_response_field_paths: Option<Vec<String>>,
    ) -> Result<ConverseOutput, LanguageModelError> {
        let mut request = self
            .converse()
            .model_id(model_id)
            .set_messages(Some(messages))
            .set_system(system)
            .set_tool_config(tool_config)
            .set_output_config(output_config)
            .set_additional_model_request_fields(additional_model_request_fields)
            .set_additional_model_response_field_paths(additional_model_response_field_paths);
        // Attach the inference configuration only when one was built, so the
        // field stays absent from the request otherwise.
        if let Some(inference_config) = inference_config {
            request = request.inference_config(inference_config);
        }
        request
            .send()
            .await
            .map_err(converse_error_to_language_model_error)
    }

    async fn converse_stream(
        &self,
        model_id: &str,
        messages: Vec<Message>,
        system: Option<Vec<SystemContentBlock>>,
        inference_config: Option<InferenceConfiguration>,
        tool_config: Option<ToolConfiguration>,
        additional_model_request_fields: Option<Document>,
        additional_model_response_field_paths: Option<Vec<String>>,
    ) -> Result<BedrockConverseStreamOutput, LanguageModelError> {
        let mut request = self
            .converse_stream()
            .model_id(model_id)
            .set_messages(Some(messages))
            .set_system(system)
            .set_tool_config(tool_config)
            .set_additional_model_request_fields(additional_model_request_fields)
            .set_additional_model_response_field_paths(additional_model_response_field_paths);
        if let Some(inference_config) = inference_config {
            request = request.inference_config(inference_config);
        }
        request
            .send()
            .await
            .map_err(converse_stream_error_to_language_model_error)
    }
}

/// Maps `Converse` SDK errors; throttling/availability/not-ready/timeout/
/// internal-server service errors are classified as transient.
fn converse_error_to_language_model_error<R>(
    error: SdkError<ConverseError, R>,
) -> LanguageModelError
where
    R: std::fmt::Debug + Send + Sync + 'static,
{
    sdk_error_to_language_model_error(error, |service_error| {
        matches!(
            service_error,
            ConverseError::ThrottlingException(_)
                | ConverseError::ServiceUnavailableException(_)
                | ConverseError::ModelNotReadyException(_)
                | ConverseError::ModelTimeoutException(_)
                | ConverseError::InternalServerException(_)
        )
    })
}

/// Maps `ConverseStream` request errors; additionally treats
/// `ModelStreamErrorException` as transient compared to the non-stream set.
fn converse_stream_error_to_language_model_error<R>(
    error: SdkError<ConverseStreamError, R>,
) -> LanguageModelError
where
    R: std::fmt::Debug + Send + Sync + 'static,
{
    sdk_error_to_language_model_error(error, |service_error| {
        matches!(
            service_error,
            ConverseStreamError::ThrottlingException(_)
                | ConverseStreamError::ServiceUnavailableException(_)
                | ConverseStreamError::ModelNotReadyException(_)
                | ConverseStreamError::ModelTimeoutException(_)
                | ConverseStreamError::InternalServerException(_)
                | ConverseStreamError::ModelStreamErrorException(_)
        )
    })
}

/// Maps errors surfaced by individual `ConverseStream` output events.
fn converse_stream_output_error_to_language_model_error<R>(
    error: SdkError<ConverseStreamOutputError, R>,
) -> LanguageModelError
where
    R: std::fmt::Debug + Send + Sync + 'static,
{
    sdk_error_to_language_model_error(error, |service_error| {
        matches!(
            service_error,
            ConverseStreamOutputError::ThrottlingException(_)
                | ConverseStreamOutputError::ServiceUnavailableException(_)
                | ConverseStreamOutputError::InternalServerException(_)
                | ConverseStreamOutputError::ModelStreamErrorException(_)
        )
    })
}

/// Shared SDK-error classifier.
///
/// Transport-level failures (timeout, dispatch, response decoding) are always
/// transient; service errors are transient only when
/// `is_transient_service_error` says so; every other variant (e.g. request
/// construction failures) is permanent. Service errors are preserved as the
/// typed error; other variants are flattened to their message chain.
fn sdk_error_to_language_model_error<E, R>(
    error: SdkError<E, R>,
    is_transient_service_error: impl Fn(&E) -> bool,
) -> LanguageModelError
where
    E: std::error::Error + Send + Sync + 'static,
    R: std::fmt::Debug + Send + Sync + 'static,
{
    let is_transient = match &error {
        SdkError::TimeoutError(_) | SdkError::DispatchFailure(_) | SdkError::ResponseError(_) => {
            true
        }
        SdkError::ServiceError(service_error) => is_transient_service_error(service_error.err()),
        _ => false,
    };
    let detailed_error = match error {
        SdkError::ServiceError(service_error) => anyhow::Error::new(service_error.into_err()),
        error => anyhow::Error::msg(error_chain_message(&error)),
    };
    if is_transient {
        LanguageModelError::transient(detailed_error)
    } else {
        LanguageModelError::permanent(detailed_error)
    }
}

/// Joins an error and all its `source()` ancestors into one
/// `"outer: inner: ..."` message.
fn error_chain_message(error: &(dyn std::error::Error + 'static)) -> String {
    std::iter::successors(Some(error), |err| err.source())
        .map(std::string::ToString::to_string)
        .collect::<Vec<_>>()
        .join(": ")
}

/// Builds an [`InferenceConfiguration`] from `options`; returns `None` when
/// no inference-related option is set so the request omits the block.
fn inference_config_from_options(options: &Options) -> Option<InferenceConfiguration> {
    let mut builder = InferenceConfiguration::builder();
    let mut has_any_value = false;
    if let Some(max_tokens) = options.max_tokens {
        builder = builder.max_tokens(max_tokens);
        has_any_value = true;
    }
    if let Some(temperature) = options.temperature {
        builder = builder.temperature(temperature);
        has_any_value = true;
    }
    if let Some(top_p) = options.top_p {
        builder = builder.top_p(top_p);
        has_any_value = true;
    }
    if let Some(stop_sequences) = &options.stop_sequences {
        builder = builder.set_stop_sequences(Some(stop_sequences.clone()));
        has_any_value = true;
    }
    has_any_value.then(|| builder.build())
}

/// Merges the typed `reasoning_effort` / `anthropic_beta` options into the
/// free-form `additional_model_request_fields` document.
///
/// - `reasoning_effort` becomes `output_config.effort`;
/// - beta headers are deduplicated into the `anthropic_beta` array, and the
///   `effort-2025-11-24` header is force-added for `claude-opus-4-5` models
///   whenever an effort is requested.
///
/// # Errors
/// Returns a permanent error when pre-existing fields have an incompatible
/// JSON shape for the merge.
fn additional_model_request_fields_from_options(
    model: &str,
    options: &Options,
) -> Result<Option<Document>, LanguageModelError> {
    // Fast path: nothing to merge, pass the raw document through untouched.
    if options.reasoning_effort.is_none() && options.anthropic_beta.is_none() {
        return Ok(options.additional_model_request_fields.clone());
    }
    let mut fields = match options.additional_model_request_fields.clone() {
        Some(Document::Object(fields)) => fields,
        Some(_) => {
            return Err(LanguageModelError::permanent(
                "Bedrock additional_model_request_fields must be an object when using anthropic_beta or reasoning_effort",
            ));
        }
        None => std::collections::HashMap::new(),
    };
    if let Some(reasoning_effort) = options.reasoning_effort {
        // Remove-and-reinsert so the existing output_config keys survive.
        let mut output_config = match fields.remove("output_config") {
            Some(Document::Object(output_config)) => output_config,
            Some(_) => {
                return Err(LanguageModelError::permanent(
                    "Bedrock additional_model_request_fields.output_config must be an object when using reasoning_effort",
                ));
            }
            None => std::collections::HashMap::new(),
        };
        output_config.insert(
            "effort".to_string(),
            Document::String(reasoning_effort.as_str().to_string()),
        );
        fields.insert("output_config".to_string(), Document::Object(output_config));
    }
    let mut anthropic_beta = match fields.remove("anthropic_beta") {
        Some(Document::Array(items)) => items
            .into_iter()
            .map(document_string)
            .collect::<Result<Vec<_>, _>>()?,
        Some(_) => {
            return Err(LanguageModelError::permanent(
                "Bedrock additional_model_request_fields.anthropic_beta must be an array of strings",
            ));
        }
        None => Vec::new(),
    };
    if let Some(extra_beta_headers) = &options.anthropic_beta {
        for beta_header in extra_beta_headers {
            push_unique_string(&mut anthropic_beta, beta_header.clone());
        }
    }
    if options.reasoning_effort.is_some() && model.contains("claude-opus-4-5") {
        push_unique_string(&mut anthropic_beta, "effort-2025-11-24".to_string());
    }
    if !anthropic_beta.is_empty() {
        fields.insert(
            "anthropic_beta".to_string(),
            Document::Array(anthropic_beta.into_iter().map(Document::String).collect()),
        );
    }
    Ok(Some(Document::Object(fields)))
}

/// Extracts a string from a [`Document`]; any other variant is a permanent
/// error (used to validate `anthropic_beta` entries).
fn document_string(value: Document) -> Result<String, LanguageModelError> {
    match value {
        Document::String(value) => Ok(value),
        _ => Err(LanguageModelError::permanent(
            "Bedrock anthropic_beta entries must be strings",
        )),
    }
}

/// Appends `value` only if it is not already present (order-preserving dedup).
fn push_unique_string(values: &mut Vec<String>, value: String) {
    if !values.iter().any(|existing| existing == &value) {
        values.push(value);
    }
}

/// Converts Bedrock [`TokenUsage`] into the swiftide [`Usage`] type.
///
/// Cache-read tokens take precedence over cache-write tokens when filling
/// `input_tokens_details.cached_tokens`; negative SDK counts collapse to 0.
fn usage_from_bedrock(usage: &TokenUsage) -> Usage {
    let cached_tokens = usage
        .cache_read_input_tokens()
        .and_then(i32_to_u32)
        .or_else(|| usage.cache_write_input_tokens().and_then(i32_to_u32));
    let details = cached_tokens.map(|cached_tokens| UsageDetails {
        input_tokens_details: Some(InputTokenDetails {
            cached_tokens: Some(cached_tokens),
        }),
        ..Default::default()
    });
    Usage {
        prompt_tokens: i32_to_u32(usage.input_tokens()).unwrap_or_default(),
        completion_tokens: i32_to_u32(usage.output_tokens()).unwrap_or_default(),
        total_tokens: i32_to_u32(usage.total_tokens()).unwrap_or_default(),
        details,
    }
}

/// Returns a context-length error only when the response carried no message,
/// no tool calls, and no reasoning AND the stop reason reports the context
/// window was exceeded; otherwise `None`.
fn context_length_exceeded_if_empty(
    has_message: bool,
    has_tool_calls: bool,
    has_reasoning: bool,
    stop_reason: Option<&StopReason>,
) -> Option<LanguageModelError> {
    if has_message
        || has_tool_calls
        || has_reasoning
        || !matches!(stop_reason, Some(StopReason::ModelContextWindowExceeded))
    {
        return None;
    }
    Some(LanguageModelError::context_length_exceeded(
        "Model context window exceeded",
    ))
}

/// Lossless `i32` -> `u32` conversion; `None` for negative values.
fn i32_to_u32(value: i32) -> Option<u32> {
    u32::try_from(value).ok()
}
/// Pretty-prints any serializable value for Langfuse trace fields.
#[cfg(feature = "langfuse")]
fn langfuse_json<T: Serialize + ?Sized>(value: &T) -> Option<String> {
    serde_json::to_string_pretty(value).ok()
}

/// Like [`langfuse_json`], but first redacts bulky payloads (large numeric
/// arrays and inline `data:` URLs) before serializing.
#[cfg(feature = "langfuse")]
fn langfuse_json_redacted<T: Serialize + ?Sized>(value: &T) -> Option<String> {
    let mut value = serde_json::to_value(value).ok()?;
    redact_sensitive_payloads(&mut value);
    serde_json::to_string_pretty(&value).ok()
}

/// Recursively walks a JSON value and shrinks payloads that would bloat a
/// trace: arrays of more than 64 unsigned integers are replaced with a
/// placeholder string, and `data:` URL strings are truncated to a preview.
#[cfg(feature = "langfuse")]
fn redact_sensitive_payloads(value: &mut serde_json::Value) {
    match value {
        serde_json::Value::Object(map) => {
            for field in map.values_mut() {
                redact_sensitive_payloads(field);
            }
        }
        serde_json::Value::Array(items) => {
            // NOTE(review): `all(..)` scans the whole array before the cheap
            // length test; harmless, but the length check could come first.
            if items.iter().all(|item| item.as_u64().is_some()) && items.len() > 64 {
                *value = serde_json::Value::String(format!("[{} bytes redacted]", items.len()));
            } else {
                for item in items {
                    redact_sensitive_payloads(item);
                }
            }
        }
        serde_json::Value::String(text) => {
            if let Some(truncated) = truncate_data_url(text) {
                *text = truncated;
            }
        }
        _ => {}
    }
}

/// Truncates the payload of a `data:` URL to a 32-character preview.
///
/// Returns `None` when `url` is not a data URL or the payload is already
/// short enough, meaning the caller keeps the original string.
#[cfg(feature = "langfuse")]
fn truncate_data_url(url: &str) -> Option<String> {
    const MAX_DATA_PREVIEW: usize = 32;
    if !url.starts_with("data:") {
        return None;
    }
    let (prefix, data) = url.split_once(',')?;
    if data.len() <= MAX_DATA_PREVIEW {
        return None;
    }
    // NOTE(review): byte slicing assumes the payload is ASCII (base64 or
    // percent-encoded); a multi-byte UTF-8 char straddling the cut would
    // panic — confirm inputs or use a char-boundary-safe truncation.
    let preview = &data[..MAX_DATA_PREVIEW];
    let truncated = data.len() - MAX_DATA_PREVIEW;
    Some(format!(
        "{prefix},{preview}...[truncated {truncated} chars]"
    ))
}

// Unit tests for the builder, option merging, usage mapping, and SDK-error
// classification helpers above.
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    use aws_sdk_bedrockruntime::{
        error::{ConnectorError, SdkError},
        operation::{converse::ConverseError, converse_stream::ConverseStreamError},
        types::{
            StopReason, TokenUsage,
            error::{
                ConverseStreamOutputError, InternalServerException, ModelNotReadyException,
                ModelStreamErrorException, ServiceUnavailableException, ThrottlingException,
                ValidationException,
            },
        },
    };
    use swiftide_core::chat_completion::errors::LanguageModelError;

    use super::*;

    // Builds a Usage whose prompt/completion split sums to `total_tokens`.
    fn usage(total_tokens: u32) -> Usage {
        Usage {
            prompt_tokens: total_tokens /
// (continues the `usage` helper above) Tests: Options::merge only overrides
// fields present on `other`; tool_strict defaults; builder merge branches.
2, completion_tokens: total_tokens - (total_tokens / 2), total_tokens, details: None, } } #[test] fn test_options_builder_and_merge_only_overrides_present_fields() { let mut base = Options::builder() .prompt_model("model-a") .max_tokens(128) .temperature(0.1) .top_p(0.8) .stop_sequences(vec!["STOP_A".to_string()]) .tool_strict(false) .build() .unwrap(); let mut request_fields = std::collections::HashMap::new(); request_fields.insert("thinking".to_string(), Document::Bool(true)); let other = Options { prompt_model: Some("model-b".to_string()), max_tokens: None, temperature: Some(0.6), top_p: None, stop_sequences: Some(vec!["STOP_B".to_string()]), tool_strict: Some(true), anthropic_beta: Some(vec!["context-1m-2025-08-07".to_string()]), reasoning_effort: Some(ReasoningEffort::Medium), additional_model_request_fields: Some(Document::Object(request_fields)), additional_model_response_field_paths: Some(vec!["/thinking".to_string()]), }; base.merge(other); assert_eq!(base.prompt_model.as_deref(), Some("model-b")); assert_eq!(base.max_tokens, Some(128)); assert_eq!(base.temperature, Some(0.6)); assert_eq!(base.top_p, Some(0.8)); assert_eq!( base.stop_sequences.as_deref(), Some(&["STOP_B".to_string()][..]) ); assert_eq!(base.tool_strict, Some(true)); assert_eq!( base.anthropic_beta.as_deref(), Some(&["context-1m-2025-08-07".to_string()][..]) ); assert_eq!(base.reasoning_effort, Some(ReasoningEffort::Medium)); assert!(base.additional_model_request_fields.is_some()); assert_eq!( base.additional_model_response_field_paths.as_deref(), Some(&["/thinking".to_string()][..]) ); } #[test] fn test_tool_strict_enabled_defaults_to_true() { assert!(Options::default().tool_strict_enabled()); assert!( !Options { tool_strict: Some(false), ..Default::default() } .tool_strict_enabled() ); } #[test] fn test_builder_default_options_and_prompt_model_merge_branches() { let mut builder = AwsBedrock::builder(); builder.test_client(MockBedrockConverse::new());
// Later default_prompt_model/default_options calls layer on top of earlier
// ones; then sync usage-callback delivery via track_completion.
builder.default_prompt_model("model-initial"); builder.default_prompt_model("model-final"); builder.default_options(Options { max_tokens: Some(64), ..Default::default() }); builder.default_options(Options { temperature: Some(0.7), ..Default::default() }); let mut client = builder.build().unwrap(); assert_eq!( client.options().prompt_model.as_deref(), Some("model-final") ); assert_eq!(client.options().max_tokens, Some(64)); assert_eq!(client.options().temperature, Some(0.7)); client.options_mut().top_p = Some(0.9); assert_eq!(client.options().top_p, Some(0.9)); assert!(format!("{client:?}").contains("AwsBedrock")); } #[test_log::test(tokio::test)] async fn test_track_completion_invokes_sync_usage_callback() { let observed = Arc::new(AtomicU32::new(0)); let observed_for_callback = observed.clone(); let mut builder = AwsBedrock::builder(); builder .test_client(MockBedrockConverse::new()) .default_prompt_model("model-a") .on_usage(move |usage| { observed_for_callback.store(usage.total_tokens, Ordering::Relaxed); Ok(()) }); let bedrock = builder.build().unwrap(); let req = serde_json::json!({"request": "value"}); let resp = serde_json::json!({"response": "value"}); let usage = usage(42); bedrock .track_completion("model-a", Some(&usage), Some(&req), Some(&resp)) .await .unwrap(); assert_eq!(observed.load(Ordering::Relaxed), 42); } #[test_log::test(tokio::test)] async fn test_track_completion_invokes_async_usage_callback() { let observed = Arc::new(AtomicU32::new(0)); let observed_for_callback = observed.clone(); let mut builder = AwsBedrock::builder(); builder .test_client(MockBedrockConverse::new()) .default_prompt_model("model-a") .on_usage_async(move |usage| { let observed_for_callback = observed_for_callback.clone(); Box::pin(async move { observed_for_callback.store(usage.total_tokens, Ordering::Relaxed); Ok(()) }) }); let bedrock = builder.build().unwrap(); let usage = usage(99); bedrock .track_completion( "model-a", Some(&usage), None::<&serde_json::Value>,
// Async callback assertion completes; then inference-config construction and
// the reasoning-effort / anthropic_beta merge happy path.
None::<&serde_json::Value>, ) .await .unwrap(); assert_eq!(observed.load(Ordering::Relaxed), 99); } #[test] fn test_inference_config_from_options_builds_only_when_values_are_set() { assert!(inference_config_from_options(&Options::default()).is_none()); let options = Options { max_tokens: Some(256), temperature: Some(0.2), top_p: Some(0.9), stop_sequences: Some(vec!["DONE".to_string()]), ..Default::default() }; let config = inference_config_from_options(&options).expect("inference config"); assert_eq!(config.max_tokens(), Some(256)); assert_eq!(config.temperature(), Some(0.2)); assert_eq!(config.top_p(), Some(0.9)); assert_eq!(config.stop_sequences(), ["DONE"]); } #[test] fn test_additional_model_request_fields_merges_reasoning_effort_and_betas() { let mut thinking = std::collections::HashMap::new(); thinking.insert("type".to_string(), Document::String("enabled".to_string())); thinking.insert("budget_tokens".to_string(), Document::from(512_u64)); let raw_beta_headers = vec![Document::String("context-1m-2025-08-07".to_string())]; let mut additional_fields = std::collections::HashMap::new(); additional_fields.insert("thinking".to_string(), Document::Object(thinking)); additional_fields.insert( "anthropic_beta".to_string(), Document::Array(raw_beta_headers), ); let options = Options { anthropic_beta: Some(vec![ "interleaved-thinking-2025-05-14".to_string(), "effort-2025-11-24".to_string(), ]), reasoning_effort: Some(ReasoningEffort::Medium), additional_model_request_fields: Some(Document::Object(additional_fields)), ..Default::default() }; let merged = additional_model_request_fields_from_options( "anthropic.claude-opus-4-5-20251101-v1:0", &options, ) .unwrap() .expect("merged additional fields"); let fields = merged.as_object().expect("object fields"); let output_config = fields .get("output_config") .and_then(Document::as_object) .expect("output_config"); assert_eq!( output_config.get("effort").and_then(Document::as_string), Some("medium") ); let thinking = fields
// Merge assertions continue; then the must-be-object validation error and
// cache-read vs cache-write preference in usage mapping.
.get("thinking") .and_then(Document::as_object) .expect("thinking"); assert_eq!( thinking.get("type").and_then(Document::as_string), Some("enabled") ); assert!(thinking.get("budget_tokens").is_some()); let anthropic_beta = fields .get("anthropic_beta") .and_then(Document::as_array) .expect("anthropic_beta"); let anthropic_beta = anthropic_beta .iter() .map(|value| value.as_string().expect("beta header string")) .collect::<Vec<_>>(); assert_eq!( anthropic_beta, vec![ "context-1m-2025-08-07", "interleaved-thinking-2025-05-14", "effort-2025-11-24", ] ); } #[test] fn test_additional_model_request_fields_requires_object_when_merging_typed_fields() { let options = Options { reasoning_effort: Some(ReasoningEffort::Low), additional_model_request_fields: Some(Document::Bool(true)), ..Default::default() }; let error = additional_model_request_fields_from_options("model", &options).unwrap_err(); assert!( error .to_string() .contains("additional_model_request_fields must be an object") ); } #[test] fn test_usage_from_bedrock_prefers_cache_read_and_falls_back_to_cache_write() { let read_usage = TokenUsage::builder() .input_tokens(10) .output_tokens(5) .total_tokens(15) .cache_read_input_tokens(3) .cache_write_input_tokens(9) .build() .unwrap(); let mapped_read = usage_from_bedrock(&read_usage); assert_eq!( mapped_read .details .as_ref() .and_then(|details| details.input_tokens_details.as_ref()) .and_then(|details| details.cached_tokens), Some(3) ); let write_usage = TokenUsage::builder() .input_tokens(10) .output_tokens(5) .total_tokens(15) .cache_write_input_tokens(7) .build() .unwrap(); let mapped_write = usage_from_bedrock(&write_usage); assert_eq!( mapped_write .details .as_ref() .and_then(|details| details.input_tokens_details.as_ref()) .and_then(|details| details.cached_tokens), Some(7) ); } #[test] fn test_usage_from_bedrock_defaults_negative_counts_to_zero() { let usage = TokenUsage::builder() .input_tokens(-1) .output_tokens(-2) .total_tokens(-3) .build() .unwrap();
// Negative-count clamping; context-length-exceeded gating; transport-level
// SDK errors classified as transient vs permanent.
let mapped = usage_from_bedrock(&usage); assert_eq!(mapped.prompt_tokens, 0); assert_eq!(mapped.completion_tokens, 0); assert_eq!(mapped.total_tokens, 0); assert_eq!(i32_to_u32(-1), None); assert_eq!(i32_to_u32(12), Some(12)); } #[test] fn test_context_length_exceeded_only_when_empty_and_context_limit_hit() { assert!( context_length_exceeded_if_empty( false, false, false, Some(&StopReason::ModelContextWindowExceeded) ) .is_some() ); assert!(context_length_exceeded_if_empty(true, false, false, None).is_none()); assert!(context_length_exceeded_if_empty(false, true, false, None).is_none()); assert!(context_length_exceeded_if_empty(false, false, true, None).is_none()); assert!( context_length_exceeded_if_empty(false, false, false, Some(&StopReason::EndTurn)) .is_none() ); } #[test] fn test_sdk_error_mapping_classifies_transient_transport_failures() { let timeout = sdk_error_to_language_model_error::<ConverseError, ()>( SdkError::timeout_error("timeout"), |_| false, ); assert!(matches!(timeout, LanguageModelError::TransientError(_))); let dispatch = sdk_error_to_language_model_error::<ConverseError, ()>( SdkError::dispatch_failure(ConnectorError::other("dispatch".into(), None)), |_| false, ); assert!(matches!(dispatch, LanguageModelError::TransientError(_))); let response = sdk_error_to_language_model_error::<ConverseError, ()>( SdkError::response_error("response", ()), |_| false, ); assert!(matches!(response, LanguageModelError::TransientError(_))); let construction = sdk_error_to_language_model_error::<ConverseError, ()>( SdkError::construction_failure("construction"), |_| false, ); assert!(matches!( construction, LanguageModelError::PermanentError(_) )); } #[test] fn test_converse_error_mapping_distinguishes_transient_and_permanent_service_errors() { let throttled = converse_error_to_language_model_error::<()>(SdkError::service_error( ConverseError::ThrottlingException(ThrottlingException::builder().build()), (), )); assert!(matches!(throttled,
// Converse / ConverseStream / stream-output service-error classification and
// error_chain_message nesting.
LanguageModelError::TransientError(_))); let validation = converse_error_to_language_model_error::<()>(SdkError::service_error( ConverseError::ValidationException(ValidationException::builder().build()), (), )); assert!(matches!(validation, LanguageModelError::PermanentError(_))); } #[test] fn test_converse_stream_error_mapping_distinguishes_transient_and_permanent_service_errors() { let unavailable = converse_stream_error_to_language_model_error::<()>(SdkError::service_error( ConverseStreamError::ServiceUnavailableException( ServiceUnavailableException::builder().build(), ), (), )); assert!(matches!(unavailable, LanguageModelError::TransientError(_))); let validation = converse_stream_error_to_language_model_error::<()>(SdkError::service_error( ConverseStreamError::ValidationException(ValidationException::builder().build()), (), )); assert!(matches!(validation, LanguageModelError::PermanentError(_))); } #[test] fn test_converse_stream_output_error_mapping_distinguishes_transient_and_permanent_service_errors() { let transient = converse_stream_output_error_to_language_model_error::<()>(SdkError::service_error( ConverseStreamOutputError::ModelStreamErrorException( ModelStreamErrorException::builder().build(), ), (), )); assert!(matches!(transient, LanguageModelError::TransientError(_))); let permanent = converse_stream_output_error_to_language_model_error::<()>(SdkError::service_error( ConverseStreamOutputError::ValidationException( ValidationException::builder().build(), ), (), )); assert!(matches!(permanent, LanguageModelError::PermanentError(_))); } #[test] fn test_error_chain_message_collects_nested_sources() { let source = std::io::Error::other("inner"); let outer = std::io::Error::other(source); let chain = error_chain_message(&outer); assert!(chain.contains("inner")); } #[test] fn test_converse_error_mapping_model_not_ready_and_stream_internal_server_are_transient() { let model_not_ready = converse_error_to_language_model_error::<()>(SdkError::service_error(
// Final transient checks; module close; extract separator; start of the
// SimplePrompt impl file (AwsBedrock::prompt delegates to ChatCompletion).
ConverseError::ModelNotReadyException(ModelNotReadyException::builder().build()), (), )); assert!(matches!( model_not_ready, LanguageModelError::TransientError(_) )); let stream_internal = converse_stream_output_error_to_language_model_error::<()>(SdkError::service_error( ConverseStreamOutputError::InternalServerException( InternalServerException::builder().build(), ), (), )); assert!(matches!( stream_internal, LanguageModelError::TransientError(_) )); } } ================================================ FILE: swiftide-integrations/src/aws_bedrock_v2/simple_prompt.rs ================================================ use async_trait::async_trait; use swiftide_core::{ ChatCompletion, chat_completion::{ChatCompletionRequest, ChatMessage, errors::LanguageModelError}, indexing::SimplePrompt, prompt::Prompt, }; #[cfg(test)] use crate::aws_bedrock_v2::Options; use super::AwsBedrock; #[async_trait] impl SimplePrompt for AwsBedrock { #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))] #[cfg_attr( feature = "langfuse", tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION")) )] async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> { let prompt_text = prompt.render()?; let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::new_user(prompt_text)]) .build() .map_err(LanguageModelError::permanent)?; let response = self.complete(&request).await?; response .message .ok_or_else(|| LanguageModelError::permanent("No text in response")) } } #[cfg(test)] mod tests { use std::collections::HashMap; use std::sync::{ Arc, atomic::{AtomicU32, Ordering}, }; use aws_sdk_bedrockruntime::Client; use aws_sdk_bedrockruntime::{ operation::converse::ConverseOutput, types::{ ContentBlock, ConversationRole, ConverseOutput as ConverseResult, Message, StopReason, TokenUsage, ToolUseBlock, }, }; use aws_smithy_types::Document; use serde_json::{Value, json}; use wiremock::{ Mock, MockServer, Request, Respond,
// (use statement continues) SimplePrompt tests: missing model fails without
// calling converse; full option pass-through to the Converse API.
ResponseTemplate, matchers::{method, path}, }; use super::*; use crate::aws_bedrock_v2::{ AwsBedrock, MockBedrockConverse, ReasoningEffort, test_utils::{TEST_MODEL_ID, bedrock_client_for_mock_server}, }; fn response_with_text(text: &str) -> ConverseOutput { ConverseOutput::builder() .output(ConverseResult::Message( Message::builder() .role(ConversationRole::Assistant) .content(ContentBlock::Text(text.to_string())) .build() .unwrap(), )) .stop_reason(StopReason::EndTurn) .build() .unwrap() } #[test_log::test(tokio::test)] async fn test_prompt_requires_model() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock.expect_converse().never(); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .build() .unwrap(); let error = bedrock.prompt("hello".into()).await.unwrap_err(); assert!(matches!(error, LanguageModelError::PermanentError(_))); } #[test_log::test(tokio::test)] async fn test_prompt_uses_converse_api_and_extracts_text() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .withf( |model_id, messages, system, inference_config, tool_config, output_config, _additional_model_request_fields, _additional_model_response_field_paths| { model_id == "anthropic.claude-3-5-sonnet-20241022-v2:0" && messages.len() == 1 && matches!(messages[0].role(), ConversationRole::User) && matches!(messages[0].content().first(), Some(ContentBlock::Text(text)) if text == "Hello") && system.is_none() && tool_config.is_none() && output_config.is_none() && inference_config .as_ref() .is_some_and(|config| { config.max_tokens() == Some(256) && config.temperature() == Some(0.4) && config.top_p() == Some(0.9) && config.stop_sequences() == ["STOP"] }) }, ) .returning(|_, _, _, _, _, _, _, _| Ok(response_with_text("Hello, world!"))); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .default_options(Options { max_tokens: Some(256), temperature: Some(0.4),
// Option struct completes; context-window stop-reason mapping and the usage
// callback (including cached-token details) on the prompt path.
top_p: Some(0.9), stop_sequences: Some(vec!["STOP".to_string()]), ..Default::default() }) .build() .unwrap(); let response = bedrock.prompt("Hello".into()).await.unwrap(); assert_eq!(response, "Hello, world!"); } #[test_log::test(tokio::test)] async fn test_prompt_maps_context_window_stop_reason() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .returning(|_, _, _, _, _, _, _, _| { Ok(ConverseOutput::builder() .stop_reason(StopReason::ModelContextWindowExceeded) .build() .unwrap()) }); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .build() .unwrap(); let error = bedrock.prompt("Hello".into()).await.unwrap_err(); assert!(matches!( error, LanguageModelError::ContextLengthExceeded(_) )); } #[test_log::test(tokio::test)] async fn test_prompt_invokes_usage_callback() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .returning(|_, _, _, _, _, _, _, _| { Ok(ConverseOutput::builder() .output(ConverseResult::Message( Message::builder() .role(ConversationRole::Assistant) .content(ContentBlock::Text("ok".to_string())) .build() .unwrap(), )) .usage( TokenUsage::builder() .input_tokens(11) .output_tokens(7) .total_tokens(18) .cache_read_input_tokens(5) .build() .unwrap(), ) .stop_reason(StopReason::EndTurn) .build() .unwrap()) }); let observed_total = Arc::new(AtomicU32::new(0)); let observed_total_for_callback = observed_total.clone(); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .on_usage(move |usage| { observed_total_for_callback.store(usage.total_tokens, Ordering::Relaxed); assert_eq!(usage.prompt_tokens, 11); assert_eq!(usage.completion_tokens, 7); assert_eq!(usage.total_tokens, 18); assert_eq!( usage .details .as_ref() .and_then(|details| details.input_tokens_details.as_ref()) .and_then(|details|
// Usage callback assertions complete; wiremock green path validating the raw
// HTTP request body against the live Converse endpoint shape.
details.cached_tokens), Some(5) ); Ok(()) }) .build() .unwrap(); let response = bedrock.prompt("Hello".into()).await.unwrap(); assert_eq!(response, "ok"); assert_eq!(observed_total.load(Ordering::Relaxed), 18); } #[test_log::test(tokio::test)] async fn test_prompt_green_path_with_wiremock() { struct ValidateConverseRequest; impl Respond for ValidateConverseRequest { fn respond(&self, request: &Request) -> ResponseTemplate { let payload: Value = serde_json::from_slice(&request.body).expect("request json"); assert_eq!(payload["messages"][0]["role"], "user"); assert_eq!( payload["messages"][0]["content"][0]["text"], "hello from prompt" ); ResponseTemplate::new(200).set_body_json(json!({ "output": { "message": { "role": "assistant", "content": [ {"text": "prompt result"} ] } }, "stopReason": "end_turn", "usage": { "inputTokens": 1, "outputTokens": 2, "totalTokens": 3 } })) } } let mock_server = MockServer::start().await; Mock::given(method("POST")) .and(path(format!("/model/{TEST_MODEL_ID}/converse"))) .respond_with(ValidateConverseRequest) .mount(&mock_server) .await; let client: Client = bedrock_client_for_mock_server(&mock_server.uri()); let bedrock = AwsBedrock::builder() .client(client) .default_prompt_model(TEST_MODEL_ID) .build() .unwrap(); let response = bedrock.prompt("hello from prompt".into()).await.unwrap(); assert_eq!(response, "prompt result"); } #[test_log::test(tokio::test)] async fn test_prompt_returns_error_when_completion_has_no_text() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .returning(|_, _, _, _, _, _, _, _| { Ok(ConverseOutput::builder() .output(ConverseResult::Message( Message::builder() .role(ConversationRole::Assistant) .content(ContentBlock::ToolUse( ToolUseBlock::builder() .tool_use_id("call_1") .name("get_weather") .input(Document::Object(HashMap::new())) .build() .unwrap(), )) .build() .unwrap(), )) .stop_reason(StopReason::ToolUse) .build() .unwrap()) }); let bedrock =
// Tool-use-only responses yield "No text in response"; ignored live smoke
// test with transient-error retry; extract separator; start of the
// DynStructuredPrompt impl file (imports).
AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .build() .unwrap(); let error = bedrock.prompt("Hello".into()).await.unwrap_err(); assert!(matches!(error, LanguageModelError::PermanentError(_))); assert!(error.to_string().contains("No text in response")); } #[ignore = "requires live AWS Bedrock access and billable model invocation"] #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn smoke_live_bedrock_reasoning_effort_prompt() { let model = std::env::var("SWIFTIDE_AWS_BEDROCK_LIVE_MODEL") .unwrap_or_else(|_| "anthropic.claude-opus-4-5-20251101-v1:0".to_string()); let prompt = "Reply with exactly 'swiftide-bedrock-effort-ok' and nothing else."; let bedrock = AwsBedrock::builder() .default_prompt_model(model.clone()) .default_options( Options::builder() .max_tokens(64) .reasoning_effort(ReasoningEffort::Low) .build() .unwrap(), ) .build() .unwrap(); let mut attempts = 0; let response = loop { attempts += 1; let attempt = tokio::time::timeout( std::time::Duration::from_secs(60), bedrock.prompt(prompt.into()), ) .await .expect("live Bedrock prompt timed out"); match attempt { Ok(response) => break response, Err(LanguageModelError::TransientError(error)) => { eprintln!("transient Bedrock error during live smoke test: {error}"); assert!( attempts < 3, "live Bedrock prompt failed after retries: {error}" ); tokio::time::sleep(std::time::Duration::from_secs(3)).await; } Err(error) => panic!("live Bedrock prompt failed: {error:?}"), } }; println!("model={model}"); println!("response={response}"); assert!( response.contains("swiftide-bedrock-effort-ok"), "unexpected response: {response}" ); } } ================================================ FILE: swiftide-integrations/src/aws_bedrock_v2/structured_prompt.rs ================================================ use async_trait::async_trait; use aws_sdk_bedrockruntime::types::{ ContentBlock, ConversationRole, JsonSchemaDefinition, Message,
// (use statement continues) DynStructuredPrompt: renders the prompt, sends a
// Converse request with a JSON-schema OutputConfig, and parses the reply as
// JSON; failures along every step are reported via track_failure.
OutputConfig, OutputFormat, OutputFormatStructure, OutputFormatType, }; use schemars::Schema; #[cfg(feature = "langfuse")] use serde_json::json; use swiftide_core::{ DynStructuredPrompt, chat_completion::errors::LanguageModelError, prompt::Prompt, }; use super::AwsBedrock; #[async_trait] impl DynStructuredPrompt for AwsBedrock { #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))] #[cfg_attr( feature = "langfuse", tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION")) )] async fn structured_prompt_dyn( &self, prompt: Prompt, schema: Schema, ) -> Result<serde_json::Value, LanguageModelError> { let prompt_text = prompt.render()?; let model = self.prompt_model()?; let schema_json = serde_json::to_string(&schema).map_err(LanguageModelError::permanent)?; #[cfg(feature = "langfuse")] let tracking_request = Some(json!({ "model": model, "prompt": prompt_text.as_str(), "schema": schema, })); #[cfg(not(feature = "langfuse"))] let tracking_request: Option<serde_json::Value> = None; let message = Message::builder() .role(ConversationRole::User) .content(ContentBlock::Text(prompt_text)) .build() .map_err(LanguageModelError::permanent)?; let output_config = OutputConfig::builder() .text_format( OutputFormat::builder() .r#type(OutputFormatType::JsonSchema) .structure(OutputFormatStructure::JsonSchema( JsonSchemaDefinition::builder() .schema(schema_json) .name("structured_prompt") .build() .map_err(LanguageModelError::permanent)?, )) .build() .map_err(LanguageModelError::permanent)?, ) .build(); let additional_model_request_fields = super::additional_model_request_fields_from_options(model, &self.default_options)?; let response = match self .client .converse( model, vec![message], None, super::inference_config_from_options(&self.default_options), None, Some(output_config), additional_model_request_fields, self.default_options .additional_model_response_field_paths .clone(), ) .await { Ok(response) => response, Err(error) => {
// Error paths: converse failure, completion-mapping failure, empty response
// (context-length aware), and JSON parse failure — each tracked before
// returning. Then the test module header.
Self::track_failure( model, tracking_request.as_ref(), None::<&serde_json::Value>, &error, ); return Err(error); } }; let completion = match super::chat_completion::response_to_chat_completion(&response) { Ok(completion) => completion, Err(error) => { Self::track_failure( model, tracking_request.as_ref(), None::<&serde_json::Value>, &error, ); return Err(error); } }; self.track_completion( model, completion.usage.as_ref(), tracking_request.as_ref(), Some(&completion), ) .await?; let Some(ref response_text) = completion.message else { if let Some(error) = super::context_length_exceeded_if_empty( false, completion.tool_calls.is_some(), completion .reasoning .as_ref() .is_some_and(|reasoning| !reasoning.is_empty()), Some(response.stop_reason()), ) { Self::track_failure(model, tracking_request.as_ref(), Some(&completion), &error); return Err(error); } let error = LanguageModelError::permanent("No text in response"); Self::track_failure(model, tracking_request.as_ref(), Some(&completion), &error); return Err(error); }; serde_json::from_str(response_text.trim()) .map_err(|error| { LanguageModelError::permanent(anyhow::anyhow!( "Failed to parse model response as JSON: {error}" )) }) .inspect_err(|error| { Self::track_failure(model, tracking_request.as_ref(), Some(&completion), error); }) } } #[cfg(test)] mod tests { use std::collections::HashMap; use std::sync::{ Arc, atomic::{AtomicU32, Ordering}, }; use aws_sdk_bedrockruntime::Client; use aws_sdk_bedrockruntime::{ operation::converse::ConverseOutput, types::{ ContentBlock, ConversationRole, ConverseOutput as ConverseResult, Message, StopReason, TokenUsage, ToolUseBlock, }, }; use aws_smithy_types::Document; use schemars::{JsonSchema, schema_for}; use serde_json::{Value, json}; use wiremock::{ Mock, MockServer, Request, Respond, ResponseTemplate, matchers::{method, path}, }; use super::*; #[cfg(feature = "langfuse")] use crate::aws_bedrock_v2::test_utils::run_with_langfuse_event_capture; use crate::aws_bedrock_v2::{
// Structured-prompt tests: JSON parsing of the schema-constrained response
// and reasoning-effort fields forwarded via additional_model_request_fields.
AwsBedrock, MockBedrockConverse, Options, ReasoningEffort, test_utils::{TEST_MODEL_ID, bedrock_client_for_mock_server}, }; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema, PartialEq, Eq)] struct StructuredOutput { answer: String, } #[test_log::test(tokio::test)] async fn test_structured_prompt_parses_json_response() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .withf( |_, messages, _, _, _, output_config, _additional_model_request_fields, _additional_model_response_field_paths| { messages .first() .and_then(|message| message.content().first()) .and_then(|content| content.as_text().ok()) .is_some_and(|text| text == "What is two times twenty one?") && output_config .as_ref() .and_then(|config| config.text_format()) .is_some_and(|format| { matches!(format.r#type(), OutputFormatType::JsonSchema) && format .structure() .and_then(|structure| structure.as_json_schema().ok()) .is_some_and(|schema| { schema.schema().contains("\"answer\"") }) }) }, ) .returning(|_, _, _, _, _, _, _, _| { Ok(ConverseOutput::builder() .output(ConverseResult::Message( Message::builder() .role(ConversationRole::Assistant) .content(ContentBlock::Text("{\"answer\":\"42\"}".to_string())) .build() .unwrap(), )) .stop_reason(StopReason::EndTurn) .build() .unwrap()) }); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .build() .unwrap(); let value = bedrock .structured_prompt_dyn( "What is two times twenty one?".into(), schema_for!(StructuredOutput), ) .await .unwrap(); assert_eq!( serde_json::from_value::<StructuredOutput>(value).unwrap(), StructuredOutput { answer: "42".to_string() } ); } #[test_log::test(tokio::test)] async fn test_structured_prompt_passes_reasoning_effort_in_additional_model_request_fields() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .withf( |model_id, _messages, _system,
// Verifies output_config.effort = "low" plus the effort beta header for the
// opus-4-5 model id; tail of the chunk is cut mid-test (langfuse failure
// tracking test continues past this extract).
_inference_config, _tool_config, output_config, additional_model_request_fields, _additional_model_response_field_paths| { model_id == "anthropic.claude-opus-4-5-20251101-v1:0" && output_config .as_ref() .and_then(|config| config.text_format()) .is_some() && additional_model_request_fields .as_ref() .is_some_and(|fields| { let Some(fields) = fields.as_object() else { return false; }; let effort_matches = fields .get("output_config") .and_then(Document::as_object) .and_then(|output_config| output_config.get("effort")) .and_then(Document::as_string) == Some("low"); let beta_matches = fields .get("anthropic_beta") .and_then(Document::as_array) .is_some_and(|betas| { betas.iter().any(|beta| { beta.as_string() == Some("effort-2025-11-24") }) }); effort_matches && beta_matches }) }, ) .returning(|_, _, _, _, _, _, _, _| { Ok(ConverseOutput::builder() .output(ConverseResult::Message( Message::builder() .role(ConversationRole::Assistant) .content(ContentBlock::Text("{\"answer\":\"42\"}".to_string())) .build() .unwrap(), )) .stop_reason(StopReason::EndTurn) .build() .unwrap()) }); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-opus-4-5-20251101-v1:0") .default_options(Options { reasoning_effort: Some(ReasoningEffort::Low), ..Default::default() }) .build() .unwrap(); let value = bedrock .structured_prompt_dyn( "What is two times twenty one?".into(), schema_for!(StructuredOutput), ) .await .unwrap(); assert_eq!( serde_json::from_value::<StructuredOutput>(value).unwrap(), StructuredOutput { answer: "42".to_string() } ); } #[cfg(feature = "langfuse")] #[test] fn test_structured_prompt_tracks_langfuse_failure_metadata_on_converse_error() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .returning(|_, _, _, _, _, _, _, _| { Err(LanguageModelError::permanent("structured prompt failed")) }); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock)
.default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .build() .unwrap(); let (result, events) = run_with_langfuse_event_capture(|| async { bedrock .structured_prompt_dyn( "Summarize this failure".into(), schema_for!(StructuredOutput), ) .await }); let error = result.expect_err("request should fail"); assert!(error.to_string().contains("structured prompt failed")); let failure_event = events .iter() .find(|event| event.contains_key("langfuse.status_message")) .expect("langfuse failure event"); assert_eq!( failure_event .get("langfuse.model") .map(std::string::String::as_str), Some("anthropic.claude-3-5-sonnet-20241022-v2:0") ); assert!( failure_event .get("langfuse.input") .is_some_and(|input| input.contains("Summarize this failure")) ); assert!( failure_event .get("langfuse.status_message") .is_some_and(|message| message.contains("structured prompt failed")) ); } #[test_log::test(tokio::test)] async fn test_structured_prompt_reports_usage() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .returning(|_, _, _, _, _, _, _, _| { Ok(ConverseOutput::builder() .output(ConverseResult::Message( Message::builder() .role(ConversationRole::Assistant) .content(ContentBlock::Text("{\"answer\":\"42\"}".to_string())) .build() .unwrap(), )) .usage( TokenUsage::builder() .input_tokens(9) .output_tokens(5) .total_tokens(14) .build() .unwrap(), ) .stop_reason(StopReason::EndTurn) .build() .unwrap()) }); let observed_total = Arc::new(AtomicU32::new(0)); let observed_total_for_callback = observed_total.clone(); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .on_usage(move |usage| { observed_total_for_callback.store(usage.total_tokens, Ordering::Relaxed); Ok(()) }) .build() .unwrap(); let _ = bedrock .structured_prompt_dyn( "What is two times twenty one?".into(), schema_for!(StructuredOutput), ) .await .unwrap(); 
assert_eq!(observed_total.load(Ordering::Relaxed), 14); } #[test_log::test(tokio::test)] async fn test_structured_prompt_green_path_with_wiremock() { struct ValidateStructuredConverseRequest; impl Respond for ValidateStructuredConverseRequest { fn respond(&self, request: &Request) -> ResponseTemplate { let payload: Value = serde_json::from_slice(&request.body).expect("request json"); assert_eq!(payload["messages"][0]["role"], "user"); assert_eq!( payload["messages"][0]["content"][0]["text"], "What is two times twenty one?" ); assert_eq!(payload["outputConfig"]["textFormat"]["type"], "json_schema"); assert_eq!( payload["outputConfig"]["textFormat"]["structure"]["jsonSchema"]["name"], "structured_prompt" ); let schema = payload["outputConfig"]["textFormat"]["structure"]["jsonSchema"]["schema"] .as_str() .expect("schema string"); assert!(schema.contains("\"answer\"")); ResponseTemplate::new(200).set_body_json(json!({ "output": { "message": { "role": "assistant", "content": [ {"text": "{\"answer\":\"42\"}"} ] } }, "stopReason": "end_turn", "usage": { "inputTokens": 2, "outputTokens": 3, "totalTokens": 5 } })) } } let mock_server = MockServer::start().await; Mock::given(method("POST")) .and(path(format!("/model/{TEST_MODEL_ID}/converse"))) .respond_with(ValidateStructuredConverseRequest) .mount(&mock_server) .await; let client: Client = bedrock_client_for_mock_server(&mock_server.uri()); let bedrock = AwsBedrock::builder() .client(client) .default_prompt_model(TEST_MODEL_ID) .build() .unwrap(); let value = bedrock .structured_prompt_dyn( "What is two times twenty one?".into(), schema_for!(StructuredOutput), ) .await .unwrap(); assert_eq!( serde_json::from_value::<StructuredOutput>(value).unwrap(), StructuredOutput { answer: "42".to_string() } ); } #[test_log::test(tokio::test)] async fn test_structured_prompt_returns_error_when_response_has_no_text() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .returning(|_, _, _, _, _, 
_, _, _| { Ok(ConverseOutput::builder() .output(ConverseResult::Message( Message::builder() .role(ConversationRole::Assistant) .content(ContentBlock::ToolUse( ToolUseBlock::builder() .tool_use_id("call_1") .name("structured_prompt") .input(Document::Object(HashMap::new())) .build() .unwrap(), )) .build() .unwrap(), )) .stop_reason(StopReason::ToolUse) .build() .unwrap()) }); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .build() .unwrap(); let error = bedrock .structured_prompt_dyn("Prompt".into(), schema_for!(StructuredOutput)) .await .unwrap_err(); assert!(matches!(error, LanguageModelError::PermanentError(_))); assert!(error.to_string().contains("No text in response")); } #[test_log::test(tokio::test)] async fn test_structured_prompt_returns_error_on_invalid_json_payload() { let mut bedrock_mock = MockBedrockConverse::new(); bedrock_mock .expect_converse() .once() .returning(|_, _, _, _, _, _, _, _| { Ok(ConverseOutput::builder() .output(ConverseResult::Message( Message::builder() .role(ConversationRole::Assistant) .content(ContentBlock::Text("not-json".to_string())) .build() .unwrap(), )) .stop_reason(StopReason::EndTurn) .build() .unwrap()) }); let bedrock = AwsBedrock::builder() .test_client(bedrock_mock) .default_prompt_model("anthropic.claude-3-5-sonnet-20241022-v2:0") .build() .unwrap(); let error = bedrock .structured_prompt_dyn("Prompt".into(), schema_for!(StructuredOutput)) .await .unwrap_err(); assert!(matches!(error, LanguageModelError::PermanentError(_))); assert!( error .to_string() .contains("Failed to parse model response as JSON") ); } } ================================================ FILE: swiftide-integrations/src/aws_bedrock_v2/test_utils.rs ================================================ use aws_credential_types::Credentials; use aws_sdk_bedrockruntime::{Client, Config, config::Region}; use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use 
serde_json::Value; pub(crate) const TEST_MODEL_ID: &str = "bedrock-test-model"; pub(crate) fn bedrock_client_for_mock_server(endpoint_url: &str) -> Client { let config = Config::builder() .behavior_version_latest() .region(Region::new("us-east-1")) .credentials_provider(Credentials::for_tests()) .endpoint_url(endpoint_url) .build(); Client::from_conf(config) } pub(crate) fn converse_stream_event(event_type: &str, payload: &Value) -> Vec<u8> { let message = Message::new_from_parts( vec![ Header::new(":message-type", HeaderValue::String("event".into())), Header::new( ":event-type", HeaderValue::String(event_type.to_owned().into()), ), Header::new( ":content-type", HeaderValue::String("application/json".into()), ), ], serde_json::to_vec(&payload).expect("serialize event payload"), ); let mut bytes = Vec::new(); aws_smithy_eventstream::frame::write_message_to(&message, &mut bytes) .expect("encode event stream frame"); bytes } #[cfg(feature = "langfuse")] pub(crate) type RecordedTracingEvent = std::collections::HashMap<String, String>; #[cfg(feature = "langfuse")] #[derive(Clone)] struct EventCaptureLayer { events: std::sync::Arc<std::sync::Mutex<Vec<RecordedTracingEvent>>>, } #[cfg(feature = "langfuse")] #[derive(Default)] struct EventFieldVisitor { fields: RecordedTracingEvent, } #[cfg(feature = "langfuse")] impl tracing::field::Visit for EventFieldVisitor { fn record_str(&mut self, field: &tracing::field::Field, value: &str) { self.fields .insert(field.name().to_string(), value.to_string()); } fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { self.fields .insert(field.name().to_string(), value.to_string()); } fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { self.fields .insert(field.name().to_string(), value.to_string()); } fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { self.fields .insert(field.name().to_string(), value.to_string()); } fn record_debug(&mut self, field: &tracing::field::Field, value: 
&dyn std::fmt::Debug) { self.fields .insert(field.name().to_string(), format!("{value:?}")); } } #[cfg(feature = "langfuse")] impl<S> tracing_subscriber::Layer<S> for EventCaptureLayer where S: tracing::Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>, { fn on_event( &self, event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>, ) { let mut visitor = EventFieldVisitor::default(); event.record(&mut visitor); self.events.lock().unwrap().push(visitor.fields); } } #[cfg(feature = "langfuse")] pub(crate) fn run_with_langfuse_event_capture<F, Fut, T>( future_factory: F, ) -> (T, Vec<RecordedTracingEvent>) where F: FnOnce() -> Fut, Fut: std::future::Future<Output = T>, { use tracing_subscriber::prelude::*; let events = std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); let subscriber = tracing_subscriber::registry().with(EventCaptureLayer { events: events.clone(), }); let dispatch = tracing::Dispatch::new(subscriber); let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .expect("test runtime"); let result = tracing::dispatcher::with_default(&dispatch, || runtime.block_on(future_factory())); let recorded_events = events.lock().expect("event capture mutex").clone(); (result, recorded_events) } ================================================ FILE: swiftide-integrations/src/aws_bedrock_v2/tool_schema.rs ================================================ use serde_json::Value; use swiftide_core::chat_completion::{ToolSpec, ToolSpecError}; pub(super) struct AwsBedrockToolSchema(Value); impl AwsBedrockToolSchema { pub(super) fn into_value(self) -> Value { self.0 } } impl TryFrom<&ToolSpec> for AwsBedrockToolSchema { type Error = ToolSpecError; fn try_from(spec: &ToolSpec) -> Result<Self, Self::Error> { Ok(Self(spec.canonical_parameters_schema_json()?)) } } ================================================ FILE: swiftide-integrations/src/dashscope/config.rs 
================================================
//! OpenAI-compatible configuration for Alibaba's Dashscope (Qwen) API,
//! plugged into `async_openai` via its `Config` trait.
use reqwest::header::{AUTHORIZATION, HeaderMap};
use secrecy::{ExposeSecret as _, SecretString};
use serde::Deserialize;

// Dashscope exposes an OpenAI-compatible endpoint under this base URL.
const DASHSCOPE_API_BASE: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";

/// Connection settings for Dashscope: base URL plus API key.
///
/// `#[serde(default)]` means any field missing during deserialization falls
/// back to [`Default::default`], which reads the key from the environment.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct DashscopeConfig {
    // Base URL for all requests; defaults to `DASHSCOPE_API_BASE`.
    api_base: String,
    // API key, wrapped in `SecretString` so it is redacted in Debug output.
    api_key: SecretString,
}

impl Default for DashscopeConfig {
    /// Default config: official compatible-mode endpoint + key from env vars.
    fn default() -> Self {
        Self {
            api_base: DASHSCOPE_API_BASE.to_string(),
            api_key: get_api_key().into(),
        }
    }
}

/// Reads the API key from `QWEN_API_KEY`, falling back to `DASHSCOPE_API_KEY`,
/// and finally to an empty string if neither is set.
fn get_api_key() -> String {
    std::env::var("QWEN_API_KEY")
        .unwrap_or_else(|_| std::env::var("DASHSCOPE_API_KEY").unwrap_or_default())
}

impl async_openai::config::Config for DashscopeConfig {
    /// Builds the `Authorization: Bearer <key>` header for every request.
    ///
    /// NOTE(review): `.parse().unwrap()` panics if the key contains bytes that
    /// are invalid in an HTTP header value — assumed not to happen for real
    /// Dashscope keys; confirm if keys can come from untrusted input.
    fn headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            format!("Bearer {}", self.api_key.expose_secret())
                .as_str()
                .parse()
                .unwrap(),
        );
        headers
    }

    /// Joins the base URL with a request path (no slash normalization).
    fn url(&self, path: &str) -> String {
        format!("{}{}", self.api_base, path)
    }

    fn api_base(&self) -> &str {
        &self.api_base
    }

    fn api_key(&self) -> &SecretString {
        &self.api_key
    }

    /// Dashscope needs no extra query parameters.
    fn query(&self) -> Vec<(&str, &str)> {
        vec![]
    }
}

================================================
FILE: swiftide-integrations/src/dashscope/mod.rs
================================================
//! Dashscope integration: a thin type alias over the generic OpenAI client,
//! specialized with [`DashscopeConfig`].
use config::DashscopeConfig;

use crate::openai;

mod config;

/// Dashscope client — the generic OpenAI client with Dashscope configuration.
pub type Dashscope = openai::GenericOpenAI<DashscopeConfig>;
impl Dashscope {
    /// Entry point for configuring a Dashscope client.
    pub fn builder() -> DashscopeBuilder {
        DashscopeBuilder::default()
    }
}

pub type DashscopeBuilder = openai::GenericOpenAIBuilder<DashscopeConfig>;
pub type DashscopeBuilderError = openai::GenericOpenAIBuilderError;
pub use openai::{Options, OptionsBuilder, OptionsBuilderError};

impl Default for Dashscope {
    fn default() -> Self {
        Dashscope::builder().build().unwrap()
    }
}

#[cfg(test)]
mod test {
    use super::*;

    // Verifies the builder stores the configured default prompt model.
    #[test]
    fn test_default_prompt_model() {
        let openai = Dashscope::builder()
            .default_prompt_model("qwen-long")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("qwen-long".to_string())
        );
        let openai = Dashscope::builder()
            .default_prompt_model("qwen-turbo")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("qwen-turbo".to_string())
        );
    }
}

================================================
FILE: swiftide-integrations/src/duckdb/extensions.sql
================================================
INSTALL vss;
INSTALL fts;

================================================
FILE: swiftide-integrations/src/duckdb/hybrid_query.sql
================================================
with fts as (
    select uuid, chunk, path,
        fts_main_{{table_name}}.match_bm25(
            uuid,
            {{query}},
            fields := chunk
        ) as score
    from {{table_name}}
    limit {{top_n}}
),
embd as (
    select uuid, chunk, path,
        array_cosine_similarity({{embedding_name}}, cast([{{embedding}}] as float[{{embedding_size}}])) as score
    from {{table_name}}
    limit {{top_n}}
),
normalized_scores as (
    select
        fts.uuid,
        fts.chunk,
        fts.path,
        fts.score as raw_fts_score,
        embd.score as raw_embd_score,
        (fts.score / (select max(score) from fts)) as norm_fts_score,
        ((embd.score + 1) / (select max(score) + 1 from embd)) as norm_embd_score
    from fts
    inner join embd on fts.uuid = embd.uuid
)
select
    uuid,
    chunk,
    path,
    raw_fts_score,
    raw_embd_score,
    norm_fts_score,
    norm_embd_score,
    -- (alpha * norm_embd_score + (1-alpha) * norm_fts_score)
    (0.8*norm_embd_score + 0.2*norm_fts_score) AS score_cc
from normalized_scores
order by score_cc desc
limit {{top_k}};

================================================
FILE: swiftide-integrations/src/duckdb/mod.rs
================================================
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use anyhow::{Context as _, Result};
use derive_builder::Builder;
use swiftide_core::{
    indexing::{Chunk, EmbeddedField},
    querying::search_strategies::HybridSearch,
};
use tera::Context;
use tokio::sync::RwLock;

pub mod node_cache;
pub mod persist;
pub mod retrieve;

const DEFAULT_INDEXING_SCHEMA: &str =
include_str!("schema.sql"); const DEFAULT_UPSERT_QUERY: &str = include_str!("upsert.sql"); const DEFAULT_HYBRID_QUERY: &str = include_str!("hybrid_query.sql"); /// Provides `Persist`, `Retrieve`, and `NodeCache` for duckdb /// /// Unfortunately Metadata is not stored. /// /// Supports the following search strategies: /// - `SimilaritySingleEmbedding` /// - `HybridSearch` (<https://motherduck.com/blog/search-using-duckdb-part-3>/) /// - Custom /// /// NOTE: The integration is not optimized for ultra large datasets / load. It might work, if it /// doesn't let us know <3. #[derive(Clone, Builder)] #[builder(setter(into))] pub struct Duckdb<T: Chunk = String> { /// The connection to the database /// /// Note that this uses the tokio version of a mutex because the duckdb connection contains a /// `RefCell`. This is not ideal, but it is what it is. #[builder(setter(custom))] connection: Arc<Mutex<duckdb::Connection>>, /// The name of the table to use for storing nodes. Defaults to "swiftide". #[builder(default = "swiftide".into())] table_name: String, /// The schema to use for the table /// /// Note that if you change the schema, you probably also need to change the upsert query. /// /// Additionally, if you intend to use vectors, you must install and load the vss extension. #[builder(default = self.default_schema())] schema: String, // The vectors to be stored, field name -> size #[builder(default)] vectors: HashMap<EmbeddedField, usize>, /// Batch size for storing nodes #[builder(default = "256")] batch_size: usize, /// Sql to upsert a node #[builder(private, default = self.default_node_upsert_sql())] node_upsert_sql: String, /// Name of the table to use for caching nodes. Defaults to `"swiftide_cache"`. 
#[builder(default = "swiftide_cache".into())] cache_table: String, /// Tracks if the cache table has been created #[builder(private, default = Arc::new(false.into()))] cache_table_created: Arc<RwLock<bool>>, // note might need a mutex /// Prefix to be used for keys stored in the database to avoid collisions. Can be used to /// manually invalidate the cache. #[builder(default = "String::new()")] cache_key_prefix: String, /// If enabled, vectors will be upserted with an ON CONFLICT DO UPDATE. If disabled, ON /// conflict does nothing. Requires `duckdb` >= 1.2.1 #[builder(default)] #[allow(dead_code)] upsert_vectors: bool, #[builder(default)] chunk_type: std::marker::PhantomData<T>, } impl<T: Chunk> std::fmt::Debug for Duckdb<T> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Duckdb") .field("connection", &"Arc<Mutex<duckdb::Connection>>") .field("table_name", &self.table_name) .field("batch_size", &self.batch_size) .finish() } } impl Duckdb<String> { pub fn builder() -> DuckdbBuilder<String> { DuckdbBuilder::<String>::default() } } impl<T: Chunk> Duckdb<T> { // pub fn builder() -> DuckdbBuilder<String> { // DuckdbBuilder::<String>::default() // } /// Name of the indexing table pub fn table_name(&self) -> &str { &self.table_name } /// Name of the cache table pub fn cache_table(&self) -> &str { &self.cache_table } /// Returns the connection to the database pub fn connection(&self) -> &Mutex<duckdb::Connection> { &self.connection } /// Creates HNSW indices on the vector fields /// /// These are *not* persisted. You must recreate them on startup. /// /// If you want to persist them, refer to the duckdb documentation. 
/// /// # Errors /// /// Errors if the connection or statement fails /// /// # Panics /// /// If the mutex locking the connection is poisoned pub fn create_vector_indices(&self) -> Result<()> { let table_name = &self.table_name; let mut conn = self.connection.lock().unwrap(); let tx = conn.transaction().context("Failed to start transaction")?; { for vector in self.vectors.keys() { tx.execute( &format!( "CREATE INDEX IF NOT EXISTS idx_{vector} ON {table_name} USING hnsw ({vector}) WITH (metric = 'cosine')", ), [], ) .context("Could not create index")?; } } tx.commit().context("Failed to commit transaction")?; Ok(()) } /// Safely creates the cache table if it does not exist. Can be used concurrently /// /// # Errors /// /// Errors if the table or index could not be created /// /// # Panics /// /// If the mutex locking the connection is poisoned pub async fn lazy_create_cache(&self) -> anyhow::Result<()> { if !*self.cache_table_created.read().await { let mut lock = self.cache_table_created.write().await; let conn = self.connection.lock().unwrap(); conn.execute( &format!( "CREATE TABLE IF NOT EXISTS {} (uuid TEXT PRIMARY KEY, path TEXT)", self.cache_table ), [], ) .context("Could not create table")?; // Create an extra index on path conn.execute( &format!( "CREATE INDEX IF NOT EXISTS idx_path ON {} (path)", self.cache_table ), [], ) .context("Could not create index")?; *lock = true; } Ok(()) } /// Formats a node key for the cache table pub fn node_key(&self, node: &swiftide_core::indexing::Node<T>) -> String { format!("{}.{}", self.cache_key_prefix, node.id()) } fn hybrid_query_sql( &self, search_strategy: &HybridSearch, query: &str, embedding: &[f32], ) -> Result<String> { let table_name = &self.table_name; // Silently ignores multiple vector fields let (field_name, embedding_size) = self .vectors .iter() .next() .context("No vectors configured")?; if self.vectors.len() > 1 { tracing::warn!( "Multiple vectors configured, but only the first one will be used: {:?}", 
self.vectors ); } let embedding = embedding .iter() .map(ToString::to_string) .collect::<Vec<_>>() .join(","); let context = Context::from_value(serde_json::json!({ "table_name": table_name, "top_n": search_strategy.top_n(), "top_k": search_strategy.top_k(), "embedding_name": field_name, "embedding_size": embedding_size, "query": wrap_and_escape(query), "embedding": embedding, }))?; let rendered = tera::Tera::one_off(DEFAULT_HYBRID_QUERY, &context, false)?; Ok(rendered) } } fn wrap_and_escape(s: &str) -> String { let quote = '\''; let mut buf = String::new(); buf.push(quote); let chars = s.chars(); for ch in chars { // escape `quote` by doubling it if ch == quote { buf.push(ch); } buf.push(ch); } buf.push(quote); buf } impl<T: Chunk> DuckdbBuilder<T> { pub fn connection(&mut self, connection: impl Into<duckdb::Connection>) -> &mut Self { self.connection = Some(Arc::new(Mutex::new(connection.into()))); self } pub fn with_vector(&mut self, field: EmbeddedField, size: usize) -> &mut Self { self.vectors .get_or_insert_with(HashMap::new) .insert(field, size); self } fn default_schema(&self) -> String { let mut context = Context::default(); context.insert("table_name", &self.table_name); context.insert("vectors", &self.vectors.clone().unwrap_or_default()); tera::Tera::one_off(DEFAULT_INDEXING_SCHEMA, &context, false) .expect("Could not render schema; infalllible") } fn default_node_upsert_sql(&self) -> String { let mut context = Context::default(); context.insert("table_name", &self.table_name); context.insert("vectors", &self.vectors.clone().unwrap_or_default()); context.insert("upsert_vectors", &self.upsert_vectors); context.insert( "vector_field_names", &self .vectors .as_ref() .map(|v| v.keys().collect::<Vec<_>>()) .unwrap_or_default(), ); tracing::info!("Rendering upsert sql"); tera::Tera::one_off(DEFAULT_UPSERT_QUERY, &context, false) .expect("could not render upsert query; infallible") } } ================================================ FILE: 
swiftide-integrations/src/duckdb/node_cache.rs
================================================
//! `NodeCache` implementation for Duckdb, used to skip nodes that were
//! already processed in a previous indexing run.
use anyhow::Context as _;
use async_trait::async_trait;
use swiftide_core::{
    NodeCache,
    indexing::{Chunk, Node},
};

use super::Duckdb;

/// Unwraps a `Result`, or logs the error and `return false`s out of the
/// enclosing function — so it is only usable inside functions returning
/// `bool` (here: `NodeCache::get`).
///
/// In debug builds an error additionally panics via `debug_assert!`, because
/// duckdb is not expected to fail here outside of genuine bugs.
macro_rules! unwrap_or_log {
    ($result:expr) => {
        match $result {
            Ok(value) => value,
            Err(e) => {
                tracing::error!("Error: {:#}", e);
                // FIX: was `debug_assert!(true, ..)`, which can never fire.
                // `false` makes debug builds fail loudly with the message
                // below, while release builds only log and return `false`.
                debug_assert!(
                    false,
                    "Duckdb should not give errors unless in very weird situations; this is a bug: {:#}",
                    e
                );
                return false;
            }
        }
    };
}

#[async_trait]
impl<T: Chunk> NodeCache for Duckdb<T> {
    type Input = T;

    /// Returns `true` if the node is already present in the cache table.
    ///
    /// Any database error is logged and treated as a cache miss (`false`),
    /// so indexing degrades to re-processing rather than aborting.
    async fn get(&self, node: &Node<T>) -> bool {
        unwrap_or_log!(
            self.lazy_create_cache()
                .await
                .context("failed to create cache table")
        );

        let sql = format!(
            "SELECT EXISTS(SELECT 1 FROM {} WHERE uuid = ?)",
            &self.cache_table
        );

        let lock = self.connection.lock().unwrap();
        let mut stmt = unwrap_or_log!(
            lock.prepare(&sql)
                // FIX: context said "for persist" — copy-paste from persist.rs.
                .context("Failed to prepare duckdb statement for cache get")
        );

        let present = unwrap_or_log!(
            stmt.query_map([self.node_key(node)], |row| row.get::<_, bool>(0))
                .context("failed to query for documents")
        )
        .next()
        .transpose();

        unwrap_or_log!(present).unwrap_or(false)
    }

    /// Records the node in the cache table; duplicate keys are ignored.
    ///
    /// Errors are logged but not propagated — `set` is best-effort.
    async fn set(&self, node: &Node<T>) {
        if let Err(err) = self
            .lazy_create_cache()
            .await
            .context("failed to create cache table")
        {
            tracing::error!("Failed to create cache table: {:#}", err);
            return;
        }

        let sql = format!(
            "INSERT INTO {} (uuid, path) VALUES (?, ?)
            ON CONFLICT (uuid) DO NOTHING",
            &self.cache_table
        );
        let lock = self.connection.lock().unwrap();
        let mut stmt = match lock
            .prepare(&sql)
            .context("Failed to prepare duckdb statement for cache set")
        {
            Ok(stmt) => stmt,
            Err(err) => {
                tracing::error!(
                    "Failed to prepare duckdb statement for cache set: {:#}",
                    err
                );
                return;
            }
        };

        if let Err(err) = stmt
            .execute([self.node_key(node), node.path.to_string_lossy().into()])
            .context("failed to insert into cache table")
        {
            tracing::error!("Failed to insert into cache table: {:#}", err);
        }
    }

    /// Drops the cache table entirely; it is lazily recreated on next use.
    async fn clear(&self) -> anyhow::Result<()> {
        let sql = format!("DROP TABLE IF EXISTS {}", &self.cache_table);
        let lock = self.connection.lock().unwrap();
        let mut stmt = lock
            .prepare(&sql)
            .context("Failed to prepare duckdb statement for cache clear")?;
        stmt.execute([]).context("failed to delete cache table")?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use swiftide_core::indexing::TextNode;

    fn setup_duckdb() -> Duckdb {
        Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_get_set() {
        let duckdb = setup_duckdb();
        let node = TextNode::new("test_get_set");

        assert!(!duckdb.get(&node).await);
        duckdb.set(&node).await;
        assert!(duckdb.get(&node).await);
    }

    #[tokio::test]
    async fn test_clear() {
        let duckdb = setup_duckdb();
        let node = TextNode::new("test_clear");

        duckdb.set(&node).await;
        assert!(duckdb.get(&node).await);
        duckdb.clear().await.unwrap();
        assert!(!duckdb.get(&node).await);
    }
}

================================================
FILE: swiftide-integrations/src/duckdb/persist.rs
================================================
use std::{
    borrow::Cow,
    path::Path,
    sync::{LazyLock, Mutex as StdMutex},
};

use anyhow::{Context as _, Result};
use async_trait::async_trait;
use duckdb::{
    Statement, ToSql, params, params_from_iter,
    types::{ToSqlOutput, Value},
};
use swiftide_core::{
    Persist,
    indexing::{self, Chunk, Metadata, Node},
};
use uuid::Uuid;
use
super::Duckdb;

// DuckDB installs extensions into a shared on-disk directory; this process-wide
// lock serializes installs (see `install_extensions` below).
static DUCKDB_EXTENSION_INSTALL_LOCK: LazyLock<StdMutex<()>> =
    LazyLock::new(|| StdMutex::new(()));

/// One SQL parameter value bound into the upsert statement for a node.
/// Variants mirror the columns the upsert query expects.
#[allow(dead_code)]
enum TextNodeValues<'a> {
    Uuid(Uuid),
    Path(&'a Path),
    Chunk(&'a str),
    Metadata(&'a Metadata),
    Embedding(Cow<'a, [f32]>),
    Null,
}

impl ToSql for TextNodeValues<'_> {
    /// Converts each variant into a duckdb parameter value.
    fn to_sql(&self) -> duckdb::Result<ToSqlOutput<'_>> {
        match self {
            // Uuid is bound as its string representation.
            TextNodeValues::Uuid(uuid) => Ok(ToSqlOutput::Owned(uuid.to_string().into())),
            // Should be borrow-able
            TextNodeValues::Path(path) => Ok(path.to_string_lossy().to_string().into()),
            TextNodeValues::Chunk(chunk) => chunk.to_sql(),
            TextNodeValues::Metadata(_metadata) => {
                // Duckdb map binding is not supported; this variant panics if hit.
                unimplemented!("maps are not yet implemented for duckdb");
                // Casting doesn't work either, the duckdb conversion is also not implemented :(
            }
            TextNodeValues::Embedding(vector) => {
                // Vectors are bound as a textual array literal "[v1,v2,...]"
                // which duckdb casts to its array type.
                let array_str = format!(
                    "[{}]",
                    vector
                        .iter()
                        .map(ToString::to_string)
                        .collect::<Vec<_>>()
                        .join(",")
                );
                Ok(ToSqlOutput::Owned(array_str.into()))
            }
            TextNodeValues::Null => Ok(ToSqlOutput::Owned(Value::Null)),
        }
    }
}

impl<T: Chunk + AsRef<str>> Duckdb<T> {
    /// Installs the `vss` and `fts` duckdb extensions (from extensions.sql).
    #[allow(clippy::unused_self)]
    fn install_extensions(&self, conn: &duckdb::Connection) -> Result<()> {
        // DuckDB extension install writes to a shared on-disk extension directory.
        // Serializing installs avoids flaky concurrent install/load behavior in tests/CI.
        let _lock = DUCKDB_EXTENSION_INSTALL_LOCK.lock().unwrap();
        conn.execute_batch(include_str!("extensions.sql"))
            .context("Failed to install duckdb extensions (vss, fts)")?;
        Ok(())
    }

    /// Binds a node's uuid, chunk, path and configured embeddings to the
    /// prepared upsert statement and executes it.
    ///
    /// Errors if the node has no vectors, or lacks a vector for any
    /// configured field.
    ///
    /// NOTE(review): the bind order of embeddings follows `self.vectors.keys()`;
    /// the upsert SQL was rendered from the same map, so the order matches for
    /// a given `Duckdb` instance — confirm if the map is ever rebuilt.
    fn store_node_on_stmt(&self, stmt: &mut Statement<'_>, node: &Node<T>) -> Result<()> {
        let mut values = vec![
            TextNodeValues::Uuid(node.id()),
            TextNodeValues::Chunk(node.chunk.as_ref()),
            TextNodeValues::Path(&node.path),
        ];

        let Some(node_vectors) = &node.vectors else {
            anyhow::bail!("Expected node to have vectors; cannot store into duckdb");
        };

        for field in self.vectors.keys() {
            let Some(vector) = node_vectors.get(field) else {
                anyhow::bail!("Expected vector for field {field} in node");
            };

            values.push(TextNodeValues::Embedding(vector.into()));
        }

        // TODO: Investigate concurrency in duckdb, maybe optimistic if it works
        stmt.execute(params_from_iter(values))
            .context("Failed to store node")?;

        Ok(())
    }
}

#[async_trait]
impl<T: Chunk + AsRef<str>> Persist for Duckdb<T> {
    type Input = T;
    type Output = T;

    /// Creates the indexing table (and its FTS index) if it does not exist.
    async fn setup(&self) -> Result<()> {
        tracing::debug!("Setting up duckdb schema");
        {
            let conn = self.connection.lock().unwrap();
            // Create if not exists does not seem to work with duckdb, so we check first
            if conn
                // Duckdb has issues with params it seems.
                .query_row(&format!("SHOW {}", self.table_name()), params![], |row| {
                    row.get::<_, String>(0)
                })
                .is_ok()
            {
                tracing::debug!("Indexing table already exists, skipping creation");
                return Ok(());
            }

            // Install extensions before schema loading so LOAD vss/fts in the schema succeeds.
self.install_extensions(&conn)?; conn.execute_batch(&self.schema) .context("Failed to create indexing table")?; tracing::debug!(schema = &self.schema, "Indexing table created"); } tokio::time::sleep(std::time::Duration::from_secs(1)).await; { let conn = self.connection.lock().unwrap(); // We need to run this separately to ensure the table is created before we create the // index conn.execute_batch(&format!( "PRAGMA create_fts_index('{}', 'uuid', 'chunk', stemmer = 'porter', stopwords = 'english', ignore = '(\\.|[^a-z])+', strip_accents = 1, lower = 1, overwrite = 0); ", self.table_name ))?; } tracing::info!("Setup completed"); Ok(()) } async fn store(&self, node: indexing::Node<T>) -> Result<indexing::Node<T>> { let lock = self.connection.lock().unwrap(); let mut stmt = lock.prepare(&self.node_upsert_sql)?; self.store_node_on_stmt(&mut stmt, &node)?; Ok(node) } async fn batch_store(&self, nodes: Vec<indexing::Node<T>>) -> indexing::IndexingStream<T> { // TODO: Must batch let mut new_nodes = Vec::with_capacity(nodes.len()); tracing::debug!("Waiting for transaction"); let mut conn = self.connection.lock().unwrap(); tracing::debug!("Got transaction"); let tx = match conn.transaction().context("Failed to start transaction") { Ok(tx) => tx, Err(err) => { return Err(err).into(); } }; tracing::debug!("Starting batch store"); { let mut stmt = match tx .prepare(&self.node_upsert_sql) .context("Failed to prepare statement") { Ok(stmt) => stmt, Err(err) => { return Err(err).into(); } }; for node in nodes { new_nodes.push(self.store_node_on_stmt(&mut stmt, &node).map(|()| node)); } }; if let Err(err) = tx.commit().context("Failed to commit transaction") { return Err(err).into(); } new_nodes.into() } } #[cfg(test)] mod tests { use futures_util::TryStreamExt as _; use indexing::{EmbeddedField, TextNode}; use super::*; #[test_log::test(tokio::test)] async fn test_persisting_nodes() { let client = Duckdb::builder() .connection(duckdb::Connection::open_in_memory().unwrap()) 
.table_name("test".to_string()) .with_vector(EmbeddedField::Combined, 3) .build() .unwrap(); let node = TextNode::new("Hello duckdb!") .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])]) .to_owned(); client.setup().await.unwrap(); client.store(node.clone()).await.unwrap(); tracing::info!("Stored node"); { let connection = client.connection.lock().unwrap(); let mut stmt = connection .prepare("SELECT uuid,path,chunk FROM test") .unwrap(); let node_iter = stmt .query_map([], |row| { Ok(( row.get::<_, String>(0).unwrap(), // id row.get::<_, String>(1).unwrap(), // chunk row.get::<_, String>(2).unwrap(), // path )) }) .unwrap(); let retrieved = node_iter.collect::<Result<Vec<_>, _>>().unwrap(); // assert_eq!(retrieved.len(), 1); } tracing::info!("Retrieved node"); // Verify the upsert and batch works let new_nodes = vec![node.clone(), node.clone(), node.clone()]; let stream_nodes: Vec<TextNode> = client .batch_store(new_nodes) .await .try_collect() .await .unwrap(); // let streamed_nodes: Vec<TextNode> = stream.try_collect().await.unwrap(); assert_eq!(stream_nodes.len(), 3); assert_eq!(stream_nodes[0], node); tracing::info!("Batch stored nodes 1"); { let connection = client.connection.lock().unwrap(); let mut stmt = connection .prepare("SELECT uuid,path,chunk FROM test") .unwrap(); let node_iter = stmt .query_map([], |row| { Ok(( row.get::<_, String>(0).unwrap(), // id row.get::<_, String>(1).unwrap(), // chunk row.get::<_, String>(2).unwrap(), // path )) }) .unwrap(); let retrieved = node_iter.collect::<Result<Vec<_>, _>>().unwrap(); assert_eq!(retrieved.len(), 1); } // Test batch store fully let mut new_node = node.clone(); new_node.chunk = "Something else".into(); let new_nodes = vec![node.clone(), new_node.clone(), new_node.clone()]; let stream = client.batch_store(new_nodes).await; let streamed_nodes: Vec<TextNode> = stream.try_collect().await.unwrap(); assert_eq!(streamed_nodes.len(), 3); assert_eq!(streamed_nodes[0], node); { let connection = 
client.connection.lock().unwrap(); let mut stmt = connection .prepare("SELECT uuid,path,chunk FROM test") .unwrap(); let node_iter = stmt .query_map([], |row| { Ok(( row.get::<_, String>(0).unwrap(), // id row.get::<_, String>(1).unwrap(), // chunk row.get::<_, String>(2).unwrap(), // path )) }) .unwrap(); let retrieved = node_iter.collect::<Result<Vec<_>, _>>().unwrap(); assert_eq!(retrieved.len(), 2); } } #[ignore = "json types are acting up in duckdb at the moment"] #[test_log::test(tokio::test)] async fn test_with_metadata() { let client = Duckdb::builder() .connection(duckdb::Connection::open_in_memory().unwrap()) .table_name("test".to_string()) .with_vector(EmbeddedField::Combined, 3) .build() .unwrap(); let mut node = TextNode::new("Hello duckdb!") .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])]) .to_owned(); node.metadata .insert("filter".to_string(), "true".to_string()); client.setup().await.unwrap(); client.store(node).await.unwrap(); tracing::info!("Stored node"); let connection = client.connection.lock().unwrap(); let mut stmt = connection .prepare("SELECT uuid,path,chunk FROM test") .unwrap(); let node_iter = stmt .query_map([], |row| { Ok(( row.get::<_, String>(0).unwrap(), // id row.get::<_, String>(1).unwrap(), // chunk row.get::<_, String>(2).unwrap(), // path row.get::<_, Value>(3).unwrap(), // path // row.get::<_, String>(3).unwrap(), // metadata // row.get::<_, Vec<f32>>(4).unwrap(), // vector )) }) .unwrap(); let retrieved = node_iter.collect::<Result<Vec<_>, _>>().unwrap(); dbg!(&retrieved); // assert_eq!(retrieved.len(), 1); let Value::Map(metadata) = &retrieved[0].3 else { panic!("Expected metadata to be a map"); }; assert_eq!(metadata.keys().count(), 1); assert_eq!( metadata.get(&Value::Text("filter".into())).unwrap(), &Value::Text("true".into()) ); } #[test_log::test(tokio::test)] async fn test_running_setup_twice() { let client = Duckdb::builder() .connection(duckdb::Connection::open_in_memory().unwrap()) 
.table_name("test".to_string()) .with_vector(EmbeddedField::Combined, 3) .build() .unwrap(); client.setup().await.unwrap(); client.setup().await.unwrap(); // Should not panic or error } #[test_log::test(tokio::test)] async fn test_persisted() { let temp_db_path = temp_dir::TempDir::new().unwrap(); let temp_db_path = temp_db_path.path().join("test_duckdb.db"); let client = Duckdb::builder() .connection(duckdb::Connection::open(temp_db_path).unwrap()) .table_name("test".to_string()) .with_vector(EmbeddedField::Combined, 3) .build() .unwrap(); let mut node = TextNode::new("Hello duckdb!") .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])]) .to_owned(); node.metadata .insert("filter".to_string(), "true".to_string()); client.setup().await.unwrap(); client.store(node).await.unwrap(); tracing::info!("Stored node"); let connection = client.connection.lock().unwrap(); let mut stmt = connection .prepare("SELECT uuid,path,chunk FROM test") .unwrap(); let node_iter = stmt .query_map([], |row| { Ok(( row.get::<_, String>(0).unwrap(), // id row.get::<_, String>(1).unwrap(), // chunk row.get::<_, String>(2).unwrap(), // path )) }) .unwrap(); let retrieved = node_iter.collect::<Result<Vec<_>, _>>().unwrap(); dbg!(&retrieved); // assert_eq!(retrieved.len(), 1); } } ================================================ FILE: swiftide-integrations/src/duckdb/retrieve.rs ================================================ use anyhow::{Context as _, Result}; use async_trait::async_trait; use swiftide_core::{ Retrieve, indexing::Chunk, querying::{ Document, Query, search_strategies::{CustomStrategy, HybridSearch, SimilaritySingleEmbedding}, states, }, }; use super::Duckdb; #[async_trait] impl<T: Chunk> Retrieve<SimilaritySingleEmbedding> for Duckdb<T> { async fn retrieve( &self, search_strategy: &SimilaritySingleEmbedding, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { let Some(embedding) = query.embedding.as_ref() else { return 
Err(anyhow::Error::msg("Missing embedding in query state")); }; let table_name = &self.table_name; // Silently ignores multiple vector fields let (field_name, embedding_size) = self .vectors .iter() .next() .context("No vectors configured")?; let limit = search_strategy.top_k(); // Ideally it should be a prepared statement, where only the new parameters lead to extra // allocations. This is possible in 1.2.1, but that version is still broken for VSS via // Rust. let sql = format!( "SELECT uuid, chunk, path FROM {table_name}\n ORDER BY array_distance({field_name}, ARRAY[{}]::FLOAT[{embedding_size}])\n LIMIT {limit}", embedding .iter() .map(ToString::to_string) .collect::<Vec<_>>() .join(",") ); tracing::trace!("[duckdb] Executing query: {}", sql); let conn = self.connection().lock().unwrap(); let mut stmt = conn .prepare(&sql) .context("Failed to prepare duckdb statement for persist")?; tracing::trace!("[duckdb] Retrieving documents"); let documents = stmt .query_map([], |row| { Ok(Document::builder() .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)]) .content(row.get::<_, String>(1)?) .build() .expect("Failed to build document; should never happen")) }) .context("failed to query for documents")? 
.collect::<Result<Vec<Document>, _>>() .context("failed to build documents")?; tracing::debug!("[duckdb] Retrieved documents"); Ok(query.retrieved_documents(documents)) } } #[async_trait] impl<T: Chunk> Retrieve<CustomStrategy<String>> for Duckdb<T> { async fn retrieve( &self, search_strategy: &CustomStrategy<String>, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { let sql = search_strategy .build_query(&query) .await .context("Failed to build query")?; tracing::debug!("[duckdb] Executing query: {}", sql); let conn = self.connection().lock().unwrap(); let mut stmt = conn .prepare(&sql) .context("Failed to prepare duckdb statement for persist")?; tracing::debug!("[duckdb] Prepared statement"); let documents = stmt .query_map([], |row| { Ok(Document::builder() .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)]) .content(row.get::<_, String>(1)?) .build() .expect("Failed to build document; should never happen")) }) .context("failed to query for documents")? 
.collect::<Result<Vec<Document>, _>>() .context("failed to build documents")?; tracing::debug!("[duckdb] Retrieved documents"); Ok(query.retrieved_documents(documents)) } } #[async_trait] impl<T: Chunk> Retrieve<HybridSearch> for Duckdb<T> { async fn retrieve( &self, search_strategy: &HybridSearch, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { let Some(embedding) = query.embedding.as_ref() else { return Err(anyhow::Error::msg("Missing embedding in query state")); }; let sql = self .hybrid_query_sql(search_strategy, query.current(), embedding) .context("Failed to build query")?; tracing::debug!("[duckdb] Executing query: {}", sql); let conn = self.connection().lock().unwrap(); let mut stmt = conn .prepare(&sql) .context("Failed to prepare duckdb statement for persist")?; tracing::debug!("[duckdb] Prepared statement"); let documents = stmt // DuckDB has issues with using `params!` :( .query_map([], |row| { Ok(Document::builder() .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)]) .content(row.get::<_, String>(1)?) .build() .expect("Failed to build document; should never happen")) }) .context("failed to query for documents")? 
.collect::<Result<Vec<Document>, _>>() .context("failed to build documents")?; tracing::debug!("[duckdb] Retrieved documents"); Ok(query.retrieved_documents(documents)) } } #[cfg(test)] mod tests { use indexing::{EmbeddedField, TextNode}; use swiftide_core::{Persist as _, indexing}; use super::*; #[test_log::test(tokio::test)] async fn test_duckdb_retrieving_documents() { let client = Duckdb::builder() .connection(duckdb::Connection::open_in_memory().unwrap()) .table_name("test".to_string()) .with_vector(EmbeddedField::Combined, 3) .build() .unwrap(); let node = TextNode::new("Hello duckdb!") .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])]) .to_owned(); client.setup().await.unwrap(); client.store(node.clone()).await.unwrap(); tracing::info!("Stored node"); let query = Query::<states::Pending>::builder() .embedding(vec![1.0, 2.0, 3.0]) .original("Some query") .build() .unwrap(); let result = client .retrieve(&SimilaritySingleEmbedding::default(), query) .await .unwrap(); assert_eq!(result.documents().len(), 1); let document = result.documents().first().unwrap(); assert_eq!(document.content(), "Hello duckdb!"); assert_eq!( document.metadata().get("id").unwrap().as_str(), Some(node.id().to_string().as_str()) ); } #[test_log::test(tokio::test)] async fn test_duckdb_retrieving_documents_hybrid() { let client = Duckdb::builder() .connection(duckdb::Connection::open_in_memory().unwrap()) .table_name("test".to_string()) .with_vector(EmbeddedField::Combined, 3) .build() .unwrap(); let node = TextNode::new("Hello duckdb!") .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])]) .to_owned(); client.setup().await.unwrap(); client.store(node.clone()).await.unwrap(); tracing::info!("Stored node"); let query = Query::<states::Pending>::builder() .embedding(vec![1.0, 2.0, 3.0]) .original("Some query") .build() .unwrap(); let result = client .retrieve(&HybridSearch::default(), query) .await .unwrap(); assert_eq!(result.documents().len(), 1); let document = 
result.documents().first().unwrap(); assert_eq!(document.content(), "Hello duckdb!"); assert_eq!( document.metadata().get("id").unwrap().as_str(), Some(node.id().to_string().as_str()) ); } } ================================================ FILE: swiftide-integrations/src/duckdb/schema.sql ================================================ LOAD vss; LOAD fts; CREATE TABLE IF NOT EXISTS {{table_name}} ( uuid TEXT PRIMARY KEY, chunk TEXT NOT NULL, path TEXT, {% for vector, size in vectors %} {{vector}} FLOAT[{{size}}], {% endfor %} ); ================================================ FILE: swiftide-integrations/src/duckdb/upsert.sql ================================================ INSERT INTO {{ table_name }} (uuid, chunk, path, {{ vector_field_names | join(sep=", ") }}) VALUES (?, ?, ?, {% for _ in range(end=vector_field_names | length) %} ?, {% endfor %} ) {% if upsert_vectors -%} ON CONFLICT (uuid) DO UPDATE SET chunk = EXCLUDED.chunk, path = EXCLUDED.path, {% for vector in vector_field_names %} {{ vector }} = EXCLUDED.{{ vector }}, {% endfor %} {% else -%} ON CONFLICT (uuid) DO NOTHING {% endif -%} ; ================================================ FILE: swiftide-integrations/src/fastembed/embedding_model.rs ================================================ use anyhow::Result; use async_trait::async_trait; use swiftide_core::{EmbeddingModel, Embeddings, chat_completion::errors::LanguageModelError}; use super::{EmbeddingModelType, FastEmbed}; #[async_trait] impl EmbeddingModel for FastEmbed { #[tracing::instrument(skip_all)] async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> { let mut embedding_model = self.embedding_model.lock().await; match &mut *embedding_model { EmbeddingModelType::Dense(model) => model .embed(input, self.batch_size) .map_err(LanguageModelError::permanent), EmbeddingModelType::Sparse(_) => Err(LanguageModelError::PermanentError( "Expected dense model, got sparse".into(), )), } } } 
================================================
FILE: swiftide-integrations/src/fastembed/mod.rs
================================================
//! `FastEmbed` integration for text embedding.
use std::sync::Arc;

use anyhow::Result;
use derive_builder::Builder;
use fastembed::{SparseTextEmbedding, TextEmbedding};

pub use swiftide_core::EmbeddingModel as _;
pub use swiftide_core::SparseEmbeddingModel as _;

mod embedding_model;
mod rerank;
mod sparse_embedding_model;

pub use rerank::Rerank;

/// The kind of `fastembed` model backing a [`FastEmbed`] instance.
///
/// A single `FastEmbed` holds exactly one of these; the dense and sparse
/// embedding trait implementations each check the variant at call time and
/// error when handed the wrong kind.
pub enum EmbeddingModelType {
    /// A dense text-embedding model ([`fastembed::TextEmbedding`]).
    Dense(TextEmbedding),
    /// A sparse text-embedding model ([`fastembed::SparseTextEmbedding`]),
    /// e.g. Splade, useful for hybrid/exact search.
    Sparse(SparseTextEmbedding),
}

// Lets callers pass a `TextEmbedding` straight into the builder's
// `embedding_model` setter (which accepts `impl Into<EmbeddingModelType>`).
impl From<TextEmbedding> for EmbeddingModelType {
    fn from(val: TextEmbedding) -> Self {
        EmbeddingModelType::Dense(val)
    }
}

// Same convenience conversion for sparse models.
impl From<SparseTextEmbedding> for EmbeddingModelType {
    fn from(val: SparseTextEmbedding) -> Self {
        EmbeddingModelType::Sparse(val)
    }
}

/// Default batch size for embedding
///
/// Matches the default batch size in [`fastembed`](https://docs.rs/fastembed)
const DEFAULT_BATCH_SIZE: usize = 256;

/// A wrapper around the `FastEmbed` library for text embedding.
///
/// Supports a variety of fast text embedding models. The default is the `Flag Embedding` model
/// with a dimension size of 384.
///
/// A default can also be used for sparse embeddings, which by default uses Splade. Sparse
/// embeddings are useful for more exact search in combination with dense vectors.
///
/// `Into` is implemented for all available models from fastembed-rs.
///
/// See the [FastEmbed documentation](https://docs.rs/fastembed) for more information on usage.
///
/// `FastEmbed` can be customized by setting the embedding model via the builder. The batch size can
/// also be set and is recommended. Batch size should match the batch size in the indexing
/// pipeline.
///
/// Note that the embedding vector dimensions need to match the dimensions of the vector database
/// collection
///
/// Requires the `fastembed` feature to be enabled.
#[derive(Builder, Clone)] #[builder( pattern = "owned", setter(strip_option), build_fn(error = "anyhow::Error") )] pub struct FastEmbed { #[builder( setter(custom), default = "Arc::new(tokio::sync::Mutex::new(TextEmbedding::try_new(Default::default())?.into()))" )] embedding_model: Arc<tokio::sync::Mutex<EmbeddingModelType>>, #[builder(default = "Some(DEFAULT_BATCH_SIZE)")] batch_size: Option<usize>, } impl std::fmt::Debug for FastEmbed { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("FastEmbedBuilder") .field("batch_size", &self.batch_size) .finish() } } impl FastEmbed { /// Tries to build a default `FastEmbed` with `Flag Embedding`. /// /// # Errors /// /// Errors if the build fails pub fn try_default() -> Result<Self> { Self::builder().build() } /// Tries to build a default `FastEmbed` for sparse embeddings using Splade /// /// # Errors /// /// Errors if the build fails pub fn try_default_sparse() -> Result<Self> { Self::builder() .embedding_model(SparseTextEmbedding::try_new( fastembed::SparseInitOptions::default(), )?) 
.build() } pub fn builder() -> FastEmbedBuilder { FastEmbedBuilder::default() } } impl FastEmbedBuilder { #[must_use] pub fn embedding_model(mut self, fastembed: impl Into<EmbeddingModelType>) -> Self { self.embedding_model = Some(Arc::new(tokio::sync::Mutex::new(fastembed.into()))); self } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_fastembed() { let fastembed = FastEmbed::try_default().unwrap(); let embeddings = fastembed.embed(vec!["hello".to_string()]).await.unwrap(); assert_eq!(embeddings.len(), 1); } #[tokio::test] async fn test_sparse_fastembed() { let fastembed = FastEmbed::try_default_sparse().unwrap(); let embeddings = fastembed .sparse_embed(vec!["hello".to_string()]) .await .unwrap(); // Model can vary in size, assert it's small and not the full dictionary (30k+) assert!(embeddings[0].values.len() > 1); assert!(embeddings[0].values.len() < 100); assert_eq!(embeddings[0].indices.len(), embeddings[0].values.len()); } } ================================================ FILE: swiftide-integrations/src/fastembed/rerank.rs ================================================ use anyhow::{Context as _, Result}; use itertools::Itertools; use std::sync::Arc; use async_trait::async_trait; use derive_builder::Builder; use fastembed::{RerankInitOptions, TextRerank}; use swiftide_core::{ TransformResponse, querying::{Query, states}, }; const TOP_K: usize = 10; // NOTE: If ever more rerank models are added (outside fastembed). This should be refactored to a // generic implementation with textrerank behind an interface. // // NOTE: Additionally, controlling what gets used for reranking from the query side (maybe not just // the original?), is also something to be said for. The usecase hasn't popped up yet. /// Reranking with [`fastembed::TextRerank`] in a query pipeline. /// /// Uses the original user query to compare with the retrieved documents. Then updates the query /// with the `TOP_K` documents with the highest rerank score. 
/// /// Can be customized with any rerank model from `fastembed` and the number of top documents to /// return. Optionally you can provide a template to render the document before reranking. #[derive(Clone, Builder)] pub struct Rerank { /// The reranker model from [`Fastembed`] #[builder( default = "Arc::new(tokio::sync::Mutex::new(TextRerank::try_new(RerankInitOptions::default()).expect(\"Failed to build default rerank from Fastembed.rs\")))", setter(into) )] model: Arc<tokio::sync::Mutex<TextRerank>>, /// The number of top documents returned by the reranker. #[builder(default = TOP_K)] top_k: usize, /// Optionally a template can be provided to render the document /// before reranking. I.e. to include metadata in the reranking. /// /// Available variables are `metadata` and `content`. /// /// Templates are rendered using Tera. #[builder(default = None)] document_template: Option<String>, /// The rerank batch size to use. Defaults to the `Fastembed` default. #[builder(default = None)] model_batch_size: Option<usize>, } impl std::fmt::Debug for Rerank { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Rerank").finish() } } impl Rerank { pub fn builder() -> RerankBuilder { RerankBuilder::default() } } impl Default for Rerank { fn default() -> Self { Self { model: Arc::new(tokio::sync::Mutex::new( TextRerank::try_new(RerankInitOptions::default()) .expect("Failed to build default rerank from Fastembed.rs"), )), top_k: TOP_K, document_template: None, model_batch_size: None, } } } #[async_trait] impl TransformResponse for Rerank { async fn transform_response( &self, query: Query<states::Retrieved>, ) -> Result<Query<states::Retrieved>> { let mut query = query; let current_documents = std::mem::take(&mut query.documents); let docs_for_rerank = if let Some(template) = &self.document_template { current_documents .iter() .map(|doc| { let context = tera::Context::from_serialize(doc)?; tera::Tera::one_off(template, &context, false) 
.context("Failed to render template") }) .collect::<Result<Vec<_>>>()? } else { current_documents .iter() .map(|doc| doc.content().to_string()) .collect() }; let mut model = self.model.lock().await; let reranked_documents = model .rerank( query.original(), docs_for_rerank .iter() .map(String::as_ref) .collect::<Vec<&str>>(), false, self.model_batch_size, ) .map_err(|e| anyhow::anyhow!("Failed to rerank documents: {e:?}"))? .iter() .take(self.top_k) .map(|r| current_documents[r.index].clone()) .collect_vec(); query.documents = reranked_documents; Ok(query) } } #[cfg(test)] mod tests { use swiftide_core::{document::Document, indexing::Metadata}; use super::*; #[tokio::test] async fn test_rerank_transform_response() { // Test reranking without a template let rerank = Rerank::builder().top_k(1).build().unwrap(); let documents = vec!["content1", "content2", "content3"] .into_iter() .map(Into::into) .collect_vec(); let query = Query::builder() .original("What is the capital of france?") .state(states::Retrieved) .documents(documents) .build() .unwrap(); let result = rerank.transform_response(query).await; assert!(result.is_ok()); let transformed_query = result.unwrap(); assert_eq!(transformed_query.documents.len(), 1); // Test reranking with a template let rerank = Rerank::builder() .top_k(1) .document_template(Some("{{ metadata.title }}".to_string())) .build() .unwrap(); let metadata = Metadata::from([("title", "Title")]); let documents = vec!["content1", "content2", "content3"] .into_iter() .map(|content| Document::new(content, Some(metadata.clone()))) .collect_vec(); let query = Query::builder() .original("What is the capital of france?") .state(states::Retrieved) .documents(documents) .build() .unwrap(); let result = rerank.transform_response(query).await; assert!(result.is_ok()); let transformed_query = result.unwrap(); assert_eq!(transformed_query.documents.len(), 1); } } ================================================ FILE: 
swiftide-integrations/src/fastembed/sparse_embedding_model.rs ================================================ use async_trait::async_trait; use swiftide_core::chat_completion::errors::LanguageModelError; use swiftide_core::{SparseEmbedding, SparseEmbeddingModel, SparseEmbeddings}; use super::{EmbeddingModelType, FastEmbed}; #[async_trait] impl SparseEmbeddingModel for FastEmbed { #[tracing::instrument(skip_all)] async fn sparse_embed( &self, input: Vec<String>, ) -> Result<SparseEmbeddings, LanguageModelError> { let mut embedding_model = self.embedding_model.lock().await; match &mut *embedding_model { EmbeddingModelType::Sparse(model) => model .embed(input, self.batch_size) .map_err(LanguageModelError::permanent) .and_then(|embeddings| { embeddings .into_iter() .map(|embedding| { let indices = embedding .indices .iter() .map(|v| u32::try_from(*v).map_err(LanguageModelError::permanent)) .collect::<Result<Vec<_>, LanguageModelError>>()?; Ok(SparseEmbedding { indices, values: embedding.values, }) }) .collect() }), EmbeddingModelType::Dense(_) => Err(LanguageModelError::PermanentError( "Expected sparse model, got dense".into(), )), } } } ================================================ FILE: swiftide-integrations/src/fluvio/loader.rs ================================================ use std::string::ToString; use anyhow::Context as _; use futures_util::{StreamExt as _, TryStreamExt as _}; use swiftide_core::{ Loader, indexing::{IndexingStream, TextNode}, }; use tokio::runtime::Handle; use super::Fluvio; impl Loader for Fluvio { type Output = String; #[tracing::instrument] fn into_stream(self) -> IndexingStream<String> { let fluvio_config = self.fluvio_config; let consumer_config = self.consumer_config_ext; let stream = tokio::task::block_in_place(|| { Handle::current().block_on(async { let client = if let Some(fluvio_config) = &fluvio_config { fluvio::Fluvio::connect_with_config(fluvio_config).await } else { fluvio::Fluvio::connect().await } .context(format!("Failed to 
connect to Fluvio {fluvio_config:?}"))?; client.consumer_with_config(consumer_config).await }) }) .expect("Failed to connect to Fluvio"); let swiftide_stream = stream .map_ok(|f| { let mut node = TextNode::new(f.get_value().to_string()); node.metadata .insert("fluvio_key", f.get_key().map(ToString::to_string)); node }) .map_err(anyhow::Error::from); swiftide_stream.boxed().into() } fn into_stream_boxed(self: Box<Self>) -> IndexingStream<String> { self.into_stream() } } #[cfg(test)] mod tests { use std::pin::Pin; use super::*; use anyhow::Result; use fluvio::{ RecordKey, consumer::ConsumerConfigExt, metadata::{customspu::CustomSpuSpec, topic::TopicSpec}, }; use flv_util::socket_helpers::ServerAddress; use futures_util::TryStreamExt; use regex::Regex; use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner}; use tokio::io::{AsyncBufRead, AsyncBufReadExt}; // NOTE: Move to test-utils / upstream to testcontainers if needed elsewhere struct FluvioCluster { sc: ContainerAsync<GenericImage>, spu: ContainerAsync<GenericImage>, partitions: u32, replicas: u32, port: u16, host_spu_port: u16, client: fluvio::Fluvio, } impl FluvioCluster { // Starts a fluvio cluster and connects the spu to the sc pub async fn start() -> Result<FluvioCluster> { static SC_PORT: u16 = 9003; static SPU_PORT1: u16 = 9010; static SPU_PORT2: u16 = 9011; static NETWORK_NAME: &str = "fluvio"; static PARTITIONS: u32 = 1; static REPLICAS: u32 = 1; let sc = GenericImage::new("infinyon/fluvio", "latest") .with_exposed_port(SC_PORT.into()) .with_wait_for(testcontainers::core::WaitFor::message_on_stdout( "started successfully", )) .with_wait_for(testcontainers::core::WaitFor::seconds(1)) .with_network(NETWORK_NAME) .with_container_name("sc") .with_cmd("./fluvio-run sc --local /fluvio/metadata".split(' ')) .with_env_var("RUST_LOG", "info") .start() .await?; let spu = GenericImage::new("infinyon/fluvio", "latest") .with_exposed_port(SPU_PORT1.into()) 
.with_wait_for(testcontainers::core::WaitFor::message_on_stdout( "started successfully", )) .with_wait_for(testcontainers::core::WaitFor::seconds(1)) .with_network(NETWORK_NAME) .with_container_name("spu") .with_cmd(format!("./fluvio-run spu -i 5001 -p spu:{SPU_PORT1} -v spu:{SPU_PORT2} --sc-addr sc:9004 --log-base-dir /fluvio/data").split(' ')) .with_env_var("RUST_LOG", "info") .start() .await?; let host_spu_port_1 = spu.get_host_port_ipv4(SPU_PORT1).await?; let sc_host_port = sc.get_host_port_ipv4(SC_PORT).await?; let endpoint = format!("127.0.0.1:{sc_host_port}"); let config = fluvio::FluvioConfig::new(&endpoint); let client = fluvio::Fluvio::connect_with_config(&config).await?; let cluster = FluvioCluster { sc, spu, port: sc_host_port, host_spu_port: host_spu_port_1, client, replicas: REPLICAS, partitions: PARTITIONS, }; cluster.connect_spu_to_sc().await; Ok(cluster) } async fn connect_spu_to_sc(&self) { let admin = self.client().admin().await; let spu_spec = CustomSpuSpec { id: 5001, public_endpoint: ServerAddress::try_from(format!("0.0.0.0:{}", self.host_spu_port)) .unwrap() .into(), private_endpoint: ServerAddress::try_from(format!("spu:{}", 9011)) .unwrap() .into(), rack: None, public_endpoint_local: None, }; admin .create("SPU".to_string(), false, spu_spec) .await .unwrap(); } pub fn forward_logs_to_tracing(&self) { Self::log_stdout(self.sc.stdout(true)); Self::log_stderr(self.sc.stderr(true)); Self::log_stdout(self.spu.stdout(true)); Self::log_stderr(self.spu.stderr(true)); } pub fn client(&self) -> &fluvio::Fluvio { &self.client } pub async fn create_topic(&self, topic_name: impl Into<String>) -> Result<()> { let admin = self.client().admin().await; let topic_spec = TopicSpec::new_computed(self.partitions, self.replicas, None); admin.create(topic_name.into(), false, topic_spec).await } fn log_stdout(reader: Pin<Box<dyn AsyncBufRead + Send>>) { let regex = Self::ansii_regex(); tokio::spawn(async move { let mut lines = reader.lines(); while let Some(line) 
= lines.next_line().await.unwrap() { let line = regex.replace_all(&line, "").to_string(); tracing::info!(line); } }); } fn log_stderr(reader: Pin<Box<dyn AsyncBufRead + Send>>) { let regex = Self::ansii_regex(); tokio::spawn(async move { let mut lines = reader.lines(); while let Some(line) = lines.next_line().await.unwrap() { let line = regex.replace_all(&line, "").to_string(); tracing::error!(line); } }); } fn ansii_regex() -> Regex { regex::Regex::new(r"\x1b\[([\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e])").unwrap() } pub fn endpoint(&self) -> String { format!("127.0.0.1:{}", self.port) } } #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn test_fluvio_loader() { static TOPIC_NAME: &str = "hello-rust"; static PARTITION_NUM: u32 = 0; let fluvio_cluster = FluvioCluster::start() .await .expect("Failed to start Fluvio cluster"); fluvio_cluster.forward_logs_to_tracing(); fluvio_cluster.create_topic(TOPIC_NAME).await.unwrap(); let client = fluvio_cluster.client(); let producer = client.topic_producer(TOPIC_NAME).await.unwrap(); producer .send(RecordKey::NULL, "Hello fluvio") .await .unwrap(); producer.flush().await.unwrap(); // Consume the topic with the loader let config = fluvio::FluvioConfig::new(fluvio_cluster.endpoint()); let loader = Fluvio::builder() .fluvio_config(&config) .consumer_config_ext( ConsumerConfigExt::builder() .topic(TOPIC_NAME) .partition(PARTITION_NUM) .offset_start(fluvio::Offset::from_end(1)) .build() .unwrap(), ) .build() .unwrap(); let node: TextNode = loader.into_stream().try_next().await.unwrap().unwrap(); assert_eq!(node.chunk, "Hello fluvio"); } } ================================================ FILE: swiftide-integrations/src/fluvio/mod.rs ================================================ //! Fluvio is a real-time streaming data transformation platform. //! //! This module provides a Fluvio loader for Swiftide and allows you to ingest //! messages from Fluvio topics and use them for RAG. //! //! 
Can be configured with [`ConsumerConfigExt`]. //! //! # Example //! //! ```no_run //! # use swiftide_integrations::fluvio::*; //! let loader = Fluvio::builder() //! .consumer_config_ext( //! ConsumerConfigExt::builder() //! .topic("Hello Fluvio") //! .partition(0) //! .offset_start(fluvio::Offset::from_end(1)) //! .build().unwrap() //! ).build().unwrap(); //! ``` use derive_builder::Builder; use fluvio::FluvioConfig; /// Re-export the fluvio config builder pub use fluvio::consumer::{ConsumerConfigExt, ConsumerConfigExtBuilder}; mod loader; #[derive(Debug, Clone, Builder)] #[builder(setter(into, strip_option))] pub struct Fluvio { /// The Fluvio consumer configuration to use. consumer_config_ext: ConsumerConfigExt, #[builder(default, setter(custom))] /// Custom connection configuration fluvio_config: Option<FluvioConfig>, } impl Fluvio { /// Creates a new Fluvio instance from a consumer extended configuration pub fn from_consumer_config(config: impl Into<ConsumerConfigExt>) -> Fluvio { Fluvio { consumer_config_ext: config.into(), fluvio_config: None, } } pub fn builder() -> FluvioBuilder { FluvioBuilder::default() } } impl FluvioBuilder { pub fn fluvio_config(&mut self, config: &FluvioConfig) -> &mut Self { self.fluvio_config = Some(Some(config.to_owned())); self } } ================================================ FILE: swiftide-integrations/src/gemini/config.rs ================================================ use reqwest::header::{AUTHORIZATION, HeaderMap}; use secrecy::{ExposeSecret as _, SecretString}; use serde::Deserialize; const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai"; #[derive(Clone, Debug, Deserialize)] #[serde(default)] pub struct GeminiConfig { api_base: String, api_key: SecretString, } impl Default for GeminiConfig { fn default() -> Self { Self { api_base: GEMINI_API_BASE.to_string(), api_key: std::env::var("GEMINI_API_KEY") .unwrap_or_else(|_| String::new()) .into(), } } } impl async_openai::config::Config for 
GeminiConfig { fn headers(&self) -> HeaderMap { let mut headers = HeaderMap::new(); headers.insert( AUTHORIZATION, format!("Bearer {}", self.api_key.expose_secret()) .as_str() .parse() .unwrap(), ); headers } fn url(&self, path: &str) -> String { format!("{}{}", self.api_base, path) } fn api_base(&self) -> &str { &self.api_base } fn api_key(&self) -> &SecretString { &self.api_key } fn query(&self) -> Vec<(&str, &str)> { vec![] } } ================================================ FILE: swiftide-integrations/src/gemini/mod.rs ================================================ //! This module provides integration with `Gemini`'s API, enabling the use of language models within //! the Swiftide project. It includes the `Gemini` struct for managing API clients and default //! options for prompt models. The module is conditionally compiled based on the "groq" feature //! flag. use crate::openai; use self::config::GeminiConfig; mod config; /// The `Gemini` struct encapsulates a `Gemini` client that implements /// [`swiftide_core::SimplePrompt`] /// /// There is also a builder available. /// /// By default it will look for a `GEMINI_API_KEY` environment variable. Note that a model /// always needs to be set, either with [`Gemini::with_default_prompt_model`] or via the builder. /// You can find available models in the Gemini documentation. /// /// Under the hood it uses [`async_openai`], with the Gemini openai mapping. This means /// some features might not work as expected. See the Gemini documentation for details. 
/// `Gemini` is the Swiftide client for the Gemini API, backed by the generic
/// OpenAI-compatible client parameterized with [`GeminiConfig`].
pub type Gemini = openai::GenericOpenAI<GeminiConfig>;
/// Builder for [`Gemini`], re-using the generic OpenAI builder.
pub type GeminiBuilder = openai::GenericOpenAIBuilder<GeminiConfig>;
/// Error type returned when building a [`Gemini`] client fails.
pub type GeminiBuilderError = openai::GenericOpenAIBuilderError;
pub use openai::{Options, OptionsBuilder, OptionsBuilderError};

impl Gemini {
    /// Returns a default [`GeminiBuilder`] for configuring a client.
    pub fn builder() -> GeminiBuilder {
        GeminiBuilder::default()
    }
}

impl Default for Gemini {
    /// Builds a client with the default configuration (API key read from the
    /// environment by [`GeminiConfig::default`]).
    ///
    /// Panics if the builder fails; not expected with the default config.
    fn default() -> Self {
        Self::builder().build().unwrap()
    }
}

================================================ FILE: swiftide-integrations/src/groq/config.rs ================================================
use reqwest::header::{AUTHORIZATION, HeaderMap};
use secrecy::{ExposeSecret as _, SecretString};
use serde::Deserialize;

// Groq exposes an OpenAI-compatible API under this base URL.
const GROQ_API_BASE: &str = "https://api.groq.com/openai/v1";

/// Connection configuration for Groq: API base URL plus the secret API key.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct GroqConfig {
    api_base: String,
    // `SecretString` keeps the key redacted in `Debug` output.
    api_key: SecretString,
}

impl Default for GroqConfig {
    /// Uses the public Groq endpoint and reads the key from the
    /// `GROQ_API_KEY` environment variable, falling back to an empty key
    /// when the variable is unset.
    fn default() -> Self {
        Self {
            api_base: GROQ_API_BASE.to_string(),
            api_key: std::env::var("GROQ_API_KEY")
                .unwrap_or_else(|_| String::new())
                .into(),
        }
    }
}

// Adapts `GroqConfig` to `async_openai`'s `Config` trait so the generic
// OpenAI client can talk to Groq's compatible endpoint.
impl async_openai::config::Config for GroqConfig {
    fn headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            // Bearer auth; `parse` only fails on invalid header characters,
            // which a well-formed key does not contain.
            format!("Bearer {}", self.api_key.expose_secret())
                .as_str()
                .parse()
                .unwrap(),
        );
        headers
    }

    /// Joins the configured base with the request path.
    fn url(&self, path: &str) -> String {
        format!("{}{}", self.api_base, path)
    }

    fn api_base(&self) -> &str {
        &self.api_base
    }

    fn api_key(&self) -> &SecretString {
        &self.api_key
    }

    // Groq needs no extra query parameters.
    fn query(&self) -> Vec<(&str, &str)> {
        vec![]
    }
}

================================================ FILE: swiftide-integrations/src/groq/mod.rs ================================================
//! This module provides integration with `Groq`'s API, enabling the use of language models within
//! the Swiftide project. It includes the `Groq` struct for managing API clients and default options
//! for prompt models. The module is conditionally compiled based on the "groq" feature flag.
use crate::openai; use self::config::GroqConfig; mod config; /// The `Groq` struct encapsulates a `Groq` client that implements [`swiftide_core::SimplePrompt`] /// /// There is also a builder available. /// /// By default it will look for a `GROQ_API_KEY` environment variable. Note that a model /// always needs to be set, either with [`Groq::with_default_prompt_model`] or via the builder. /// You can find available models in the Groq documentation. /// /// Under the hood it uses [`async_openai`], with the Groq openai mapping. This means /// some features might not work as expected. See the Groq documentation for details. pub type Groq = openai::GenericOpenAI<GroqConfig>; pub type GroqBuilder = openai::GenericOpenAIBuilder<GroqConfig>; pub type GroqBuilderError = openai::GenericOpenAIBuilderError; pub use openai::{Options, OptionsBuilder, OptionsBuilderError}; impl Groq { pub fn builder() -> GroqBuilder { GroqBuilder::default() } } impl Default for Groq { fn default() -> Self { Self::builder().build().unwrap() } } ================================================ FILE: swiftide-integrations/src/kafka/loader.rs ================================================ use futures_util::{StreamExt as _, stream}; use rdkafka::{ Message, consumer::{Consumer, StreamConsumer}, message::BorrowedMessage, }; use swiftide_core::{Loader, indexing::IndexingStream, indexing::Node}; use super::Kafka; impl Loader for Kafka { type Output = String; #[tracing::instrument] fn into_stream(self) -> IndexingStream<String> { let client_config = self.client_config; let topic = self.topic.clone(); let consumer: StreamConsumer = client_config .create() .expect("Failed to create Kafka consumer"); consumer .subscribe(&[&topic]) .expect("Failed to subscribe to topic"); let swiftide_stream = stream::unfold(consumer, |consumer| async move { loop { match consumer.recv().await { Ok(message) => { // only handle Some(Ok(s)) if let Some(Ok(payload)) = message.payload_view::<str>() { let mut node = 
Node::<String>::new(payload); msg_metadata(&mut node, &message); tracing::trace!(?node, ?payload, "received message"); return Some((Ok(node), consumer)); } // otherwise, like a message with an invalid payload or payload is None tracing::debug!("Skipping message with invalid payload"); } Err(e) => return Some((Err(anyhow::Error::from(e)), consumer)), } } }); swiftide_stream.boxed().into() } fn into_stream_boxed(self: Box<Self>) -> IndexingStream<String> { (*self).into_stream() } } fn msg_metadata(node: &mut Node<String>, message: &BorrowedMessage) { // Add Kafka-specific metadata node.metadata .insert("kafka_topic", message.topic().to_string()); node.metadata .insert("kafka_partition", message.partition().to_string()); node.metadata .insert("kafka_offset", message.offset().to_string()); // Add timestamp if present if let Some(timestamp) = message.timestamp().to_millis() { node.metadata .insert("kafka_timestamp", timestamp.to_string()); } // Add key if present if let Some(Ok(key)) = message.key_view::<str>() { node.metadata.insert("kafka_key", key.to_string()); } } #[cfg(test)] mod tests { use std::time::Duration; use super::*; use crate::kafka::Kafka; use anyhow::Result; use futures_util::TryStreamExt; use rdkafka::{ ClientConfig, admin::{AdminClient, AdminOptions, NewTopic, TopicReplication}, client::DefaultClientContext, producer::{FutureProducer, FutureRecord, Producer}, }; use swiftide_core::indexing::TextNode; use testcontainers::{ContainerAsync, runners::AsyncRunner}; use testcontainers_modules::kafka::apache::{self}; struct KafkaBroker { _broker: ContainerAsync<apache::Kafka>, partitions: i32, replicas: i32, client_config: ClientConfig, } impl KafkaBroker { pub async fn start() -> Result<Self> { static PARTITIONS: i32 = 1; static REPLICAS: i32 = 1; let kafka_node = apache::Kafka::default().start().await?; let bootstrap_servers = format!( "127.0.0.1:{}", kafka_node.get_host_port_ipv4(apache::KAFKA_PORT).await? 
); let mut client_config = ClientConfig::new(); client_config.set("bootstrap.servers", &bootstrap_servers); client_config.set("group.id", "group_id"); client_config.set("auto.offset.reset", "earliest"); let broker = KafkaBroker { _broker: kafka_node, client_config, partitions: PARTITIONS, replicas: REPLICAS, }; Ok(broker) } pub async fn create_topic(&self, topic: impl AsRef<str>) -> Result<()> { let admin = self.admin_client(); admin .create_topics( &[NewTopic { name: topic.as_ref(), num_partitions: self.partitions, replication: TopicReplication::Fixed(self.replicas), config: vec![], }], &AdminOptions::default(), ) .await .expect("topic creation failed"); Ok(()) } fn admin_client(&self) -> AdminClient<DefaultClientContext> { self.client_config.create().unwrap() } fn producer(&self) -> FutureProducer { self.client_config.create().unwrap() } } #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn test_kafka_loader() { static TOPIC_NAME: &str = "topic"; let kafka_broker = KafkaBroker::start().await.unwrap(); kafka_broker.create_topic(TOPIC_NAME).await.unwrap(); let producer = kafka_broker.producer(); producer .send( FutureRecord::to(TOPIC_NAME).payload("payload").key("key"), Duration::from_secs(0), ) .await .unwrap(); producer.flush(Duration::from_secs(0)).unwrap(); let loader = Kafka::builder() .client_config(kafka_broker.client_config.clone()) .topic(TOPIC_NAME) .build() .unwrap(); let node: TextNode = loader.into_stream().try_next().await.unwrap().unwrap(); assert_eq!(node.chunk, "payload"); } } ================================================ FILE: swiftide-integrations/src/kafka/mod.rs ================================================ //! Kafka is a distributed streaming platform. //! //! This module provides a Kafka loader for Swiftide and allows you to ingest //! messages from Kafka topics and use them for RAG. //! //! Can be configured with [`ClientConfig`]. //! //! # Example //! //! ```no_run //! # use swiftide_integrations::kafka::*; //! 
let kafka = Kafka::builder() //! .client_config(ClientConfig::new()) //! .topic("Hello Kafka") //! .build().unwrap(); //! ``` use anyhow::{Context, Result}; use derive_builder::Builder; use rdkafka::{ admin::{AdminClient, AdminOptions, NewTopic, TopicReplication}, client::DefaultClientContext, consumer::{Consumer, StreamConsumer}, producer::FutureProducer, }; use swiftide_core::indexing::TextNode; pub use rdkafka::config::ClientConfig; mod loader; mod persist; #[derive(Debug, Clone, Builder)] #[builder(setter(into, strip_option))] pub struct Kafka { client_config: ClientConfig, topic: String, #[builder(default)] /// Customize the key used for persisting nodes persist_key_fn: Option<fn(&TextNode) -> Result<String>>, #[builder(default)] /// Customize the value used for persisting nodes persist_payload_fn: Option<fn(&TextNode) -> Result<String>>, #[builder(default = "1")] partition: i32, #[builder(default = "1")] factor: i32, #[builder(default)] create_topic_if_not_exists: bool, #[builder(default = "32")] batch_size: usize, } impl Kafka { pub fn from_client_config(config: impl Into<ClientConfig>, topic: impl Into<String>) -> Kafka { Kafka { client_config: config.into(), topic: topic.into(), persist_key_fn: None, persist_payload_fn: None, partition: 1, factor: 1, create_topic_if_not_exists: false, batch_size: 32, } } pub fn builder() -> KafkaBuilder { KafkaBuilder::default() } fn producer(&self) -> Result<FutureProducer<DefaultClientContext>> { self.client_config .create() .context("Failed to create producer") } fn topic_exists(&self) -> Result<bool> { let consumer: StreamConsumer = self .client_config .create() .context("Failed to create consumer")?; let metadata = consumer.fetch_metadata(Some(&self.topic), None)?; Ok(!metadata.topics().is_empty()) } async fn create_topic(&self) -> Result<()> { let admin_client: AdminClient<DefaultClientContext> = self .client_config .create() .context("Failed to create admin client")?; admin_client .create_topics( 
vec![&NewTopic::new( &self.topic, self.partition, TopicReplication::Fixed(self.factor), )], &AdminOptions::new(), ) .await?; Ok(()) } /// Generates a ky for a given node to be persisted in Kafka. fn persist_key_for_node(&self, node: &TextNode) -> Result<String> { if let Some(key_fn) = self.persist_key_fn { key_fn(node) } else { let hash = node.id(); Ok(format!("{}:{}", node.path.to_string_lossy(), hash)) } } /// Generates a value for a given node to be persisted in Kafka. /// By default, the node is serialized as JSON. /// If a custom function is provided, it is used to generate the value. /// Otherwise, the node is serialized as JSON. fn persist_value_for_node(&self, node: &TextNode) -> Result<String> { if let Some(value_fn) = self.persist_payload_fn { value_fn(node) } else { Ok(serde_json::to_string(node)?) } } fn node_to_key_payload(&self, node: &TextNode) -> Result<(String, String)> { let key = self .persist_key_for_node(node) .map_err(|e| anyhow::anyhow!("persist_key_for_node failed: {e:?} (node: {node:?})"))?; let payload = self.persist_value_for_node(node).map_err(|e| { anyhow::anyhow!("persist_value_for_node failed: {e:?} (node: {node:?})") })?; Ok((key, payload)) } } ================================================ FILE: swiftide-integrations/src/kafka/persist.rs ================================================ use std::{sync::Arc, time::Duration}; use anyhow::Result; use async_trait::async_trait; use rdkafka::producer::FutureRecord; use swiftide_core::{ Persist, indexing::{IndexingStream, TextNode}, }; use super::Kafka; #[async_trait] impl Persist for Kafka { type Input = String; type Output = String; async fn setup(&self) -> Result<()> { if self.topic_exists()? 
{ return Ok(()); } if !self.create_topic_if_not_exists { return Err(anyhow::anyhow!("Topic {} does not exist", self.topic)); } self.create_topic().await?; Ok(()) } fn batch_size(&self) -> Option<usize> { Some(self.batch_size) } async fn store(&self, node: TextNode) -> Result<TextNode> { let (key, payload) = self.node_to_key_payload(&node)?; self.producer()? .send( FutureRecord::to(&self.topic).key(&key).payload(&payload), Duration::from_secs(0), ) .await .map_err(|(e, _)| anyhow::anyhow!("Failed to send node: {e:?}"))?; Ok(node) } async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> { let producer = Arc::new(self.producer().expect("Failed to create producer")); for node in &nodes { match self.node_to_key_payload(node) { Ok((key, payload)) => { if let Err(e) = producer .send( FutureRecord::to(&self.topic).payload(&payload).key(&key), Duration::from_secs(0), ) .await { return vec![Err(anyhow::anyhow!("failed to send node: {e:?}"))].into(); } } Err(e) => { return vec![Err(e)].into(); } } } IndexingStream::iter(nodes.into_iter().map(Ok)) } } #[cfg(test)] mod tests { use super::*; use futures_util::TryStreamExt; use rdkafka::ClientConfig; use testcontainers::runners::AsyncRunner; use testcontainers_modules::kafka::apache::{self}; #[test_log::test(tokio::test)] async fn test_kafka_persist() { static TOPIC_NAME: &str = "topic"; let kafka_node = apache::Kafka::default() .start() .await .expect("failed to start kafka"); let bootstrap_servers = format!( "127.0.0.1:{}", kafka_node .get_host_port_ipv4(apache::KAFKA_PORT) .await .expect("failed to get kafka port") ); let mut client_config = ClientConfig::new(); client_config.set("bootstrap.servers", &bootstrap_servers); let storage = Kafka::builder() .client_config(client_config) .topic(TOPIC_NAME) .build() .unwrap(); let node = TextNode::new("chunk"); storage.setup().await.unwrap(); storage.store(node.clone()).await.unwrap(); } #[test_log::test(tokio::test)] async fn test_kafka_batch_persist() { static 
TOPIC_NAME: &str = "topic"; let kafka_node = apache::Kafka::default() .start() .await .expect("failed to start kafka"); let bootstrap_servers = format!( "127.0.0.1:{}", kafka_node .get_host_port_ipv4(apache::KAFKA_PORT) .await .expect("failed to get kafka port") ); let mut client_config = ClientConfig::new(); client_config.set("bootstrap.servers", &bootstrap_servers); let storage = Kafka::builder() .client_config(client_config) .topic(TOPIC_NAME) .create_topic_if_not_exists(true) .batch_size(2usize) .build() .unwrap(); let nodes = vec![TextNode::default(); 6]; storage.setup().await.unwrap(); let stream = storage.batch_store(nodes.clone()).await; let result: Vec<TextNode> = stream.try_collect().await.unwrap(); assert_eq!(result.len(), 6); assert_eq!(result[0], nodes[0]); assert_eq!(result[1], nodes[1]); assert_eq!(result[2], nodes[2]); assert_eq!(result[3], nodes[3]); assert_eq!(result[4], nodes[4]); assert_eq!(result[5], nodes[5]); } } ================================================ FILE: swiftide-integrations/src/lancedb/connection_pool.rs ================================================ use anyhow::Context as _; use anyhow::Result; use deadpool::managed::Manager; use derive_builder::Builder; use lancedb::connection::ConnectBuilder; #[derive(Builder, Debug, Clone)] #[builder(setter(into), build_fn(error = "anyhow::Error"))] pub struct LanceDBPoolManager { uri: String, #[builder(default)] api_key: Option<String>, #[builder(default)] region: Option<String>, #[builder(default)] storage_options: Vec<(String, String)>, } pub type LanceDBConnectionPool = deadpool::managed::Pool<LanceDBPoolManager>; impl LanceDBPoolManager { pub fn builder() -> LanceDBPoolManagerBuilder { LanceDBPoolManagerBuilder::default() } } impl Manager for LanceDBPoolManager { type Type = lancedb::Connection; type Error = anyhow::Error; async fn create(&self) -> Result<Self::Type, Self::Error> { let mut builder = ConnectBuilder::new(&self.uri); if let Some(api_key) = &self.api_key { builder = 
builder.api_key(api_key); } if let Some(region) = &self.region { builder = builder.region(region); } for (key, value) in &self.storage_options { builder = builder.storage_option(key, value); } builder .execute() .await .context("Failed to create LanceDB connection") } async fn recycle( &self, _obj: &mut Self::Type, _metrics: &deadpool::managed::Metrics, ) -> deadpool::managed::RecycleResult<Self::Error> { // NOTE: Should work fine with drop Ok(()) } } ================================================ FILE: swiftide-integrations/src/lancedb/mod.rs ================================================ use std::sync::Arc; use anyhow::Context as _; use anyhow::Result; use connection_pool::LanceDBConnectionPool; use connection_pool::LanceDBPoolManager; use deadpool::managed::Object; use derive_builder::Builder; use lancedb::arrow::arrow_schema::{DataType, Field, Schema}; use swiftide_core::indexing::EmbeddedField; pub mod connection_pool; pub mod persist; pub mod retrieve; /// `LanceDB` is a columnar database that separates data and compute. /// /// This enables local, embedded databases, or storing in a cloud storage. /// /// See examples for more information. /// /// Implements `Persist` and `Retrieve`. /// /// If you want to store / retrieve metadata in Lance, the columns can be defined with /// `with_metadata`. /// /// Note: For querying large tables you manually need to create an index. You can get an /// active connection via `get_connection`. 
/// /// # Example /// /// ```no_run /// # use swiftide_integrations::lancedb::{LanceDB}; /// # use swiftide_core::indexing::EmbeddedField; /// LanceDB::builder() /// .uri("/my/lancedb") /// .vector_size(1536) /// .with_vector(EmbeddedField::Combined) /// .with_metadata("Metadata field to also store") /// .table_name("swiftide_test") /// .build() /// .unwrap(); #[derive(Builder, Clone)] #[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))] #[allow(dead_code)] pub struct LanceDB { /// Connection pool for `LanceDB` /// By default will use settings provided when creating the instance. #[builder(default = "self.default_connection_pool()?")] connection_pool: Arc<LanceDBConnectionPool>, /// Set the URI. Required unless a connection pool is provided. uri: Option<String>, /// The maximum number of connections, defaults to 10. #[builder(default = "Some(10)")] pool_size: Option<usize>, /// Optional API key #[builder(default)] api_key: Option<String>, /// Optional Region #[builder(default)] region: Option<String>, /// Storage options #[builder(default)] storage_options: Vec<(String, String)>, #[builder(private, default = "self.default_schema_from_fields()")] schema: Arc<Schema>, /// The name of the table to store the data /// By default will use `swiftide` #[builder(default = "\"swiftide\".into()")] table_name: String, /// Default sizes of vectors. Vectors can also be of different /// sizes by specifying the size in the vector configuration. vector_size: Option<i32>, /// Batch size for storing nodes in `LanceDB`. Default is 256. #[builder(default = "256")] batch_size: usize, /// Field configuration for `LanceDB`, will result in the eventual schema. /// /// Supports multiple field types, see [`FieldConfig`] for more details. 
#[builder(default = "self.default_fields()")] fields: Vec<FieldConfig>, } impl std::fmt::Debug for LanceDB { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct("LanceDB") .field("schema", &self.schema) .finish() } } impl LanceDB { pub fn builder() -> LanceDBBuilder { LanceDBBuilder::default() } /// Get a connection to `LanceDB` from the connection pool /// /// # Errors /// /// Returns an error if the connection cannot be retrieved. pub async fn get_connection(&self) -> Result<Object<LanceDBPoolManager>> { Box::pin(self.connection_pool.get()) .await .map_err(|e| anyhow::anyhow!(e)) } /// Opens the lancedb table /// /// # Errors /// /// Returns an error if the table cannot be opened or the connection cannot be acquired. pub async fn open_table(&self) -> Result<lancedb::Table> { let conn = self.get_connection().await?; conn.open_table(&self.table_name) .execute() .await .context("Failed to open table") } } impl LanceDBBuilder { #[allow(clippy::missing_panics_doc)] pub fn with_vector(&mut self, config: impl Into<VectorConfig>) -> &mut Self { if self.fields.is_none() { self.fields(self.default_fields()); } self.fields .as_mut() .unwrap() .push(FieldConfig::Vector(config.into())); self } #[allow(clippy::missing_panics_doc)] pub fn with_metadata(&mut self, config: impl Into<MetadataConfig>) -> &mut Self { if self.fields.is_none() { self.fields(self.default_fields()); } self.fields .as_mut() .unwrap() .push(FieldConfig::Metadata(config.into())); self } #[allow(clippy::unused_self)] fn default_fields(&self) -> Vec<FieldConfig> { vec![FieldConfig::ID, FieldConfig::Chunk] } fn default_schema_from_fields(&self) -> Arc<Schema> { let mut fields = Vec::new(); let vector_size = self.vector_size; for field in self.fields.as_deref().unwrap_or(&self.default_fields()) { match field { FieldConfig::Vector(config) => { let vector_size = config.vector_size.or(vector_size.flatten()).expect( "Vector size should be set either in the field or in the LanceDB 
builder", ); fields.push(Field::new( config.field_name(), DataType::FixedSizeList( Arc::new(Field::new("item", DataType::Float32, true)), vector_size, ), true, )); } FieldConfig::Chunk => { fields.push(Field::new(field.field_name(), DataType::Utf8, false)); } FieldConfig::Metadata(_) => { fields.push(Field::new(field.field_name(), DataType::Utf8, true)); } FieldConfig::ID => { fields.push(Field::new( field.field_name(), DataType::FixedSizeList( Arc::new(Field::new("item", DataType::UInt8, true)), 16, ), false, )); } } } Arc::new(Schema::new(fields)) } fn default_connection_pool(&self) -> Result<Arc<LanceDBConnectionPool>> { let mgr = LanceDBPoolManager::builder() .uri(self.uri.clone().flatten().context("URI should be set")?) .api_key(self.api_key.clone().flatten()) .region(self.region.clone().flatten()) .storage_options(self.storage_options.clone().unwrap_or_default()) .build()?; LanceDBConnectionPool::builder(mgr) .max_size(self.pool_size.flatten().unwrap_or(10)) .build() .map(Arc::new) .map_err(Into::into) } } #[derive(Clone)] pub enum FieldConfig { Vector(VectorConfig), Metadata(MetadataConfig), Chunk, ID, } impl FieldConfig { pub fn field_name(&self) -> String { match self { FieldConfig::Vector(config) => config.field_name(), FieldConfig::Metadata(config) => config.field.clone(), FieldConfig::Chunk => "chunk".into(), FieldConfig::ID => "id".into(), } } } #[derive(Clone)] pub struct VectorConfig { embedded_field: EmbeddedField, vector_size: Option<i32>, } impl VectorConfig { pub fn field_name(&self) -> String { format!( "vector_{}", normalize_field_name(&self.embedded_field.to_string()) ) } } impl From<EmbeddedField> for VectorConfig { fn from(val: EmbeddedField) -> Self { VectorConfig { embedded_field: val, vector_size: None, } } } #[derive(Clone)] pub struct MetadataConfig { field: String, original_field: String, } impl<T: AsRef<str>> From<T> for MetadataConfig { fn from(val: T) -> Self { MetadataConfig { field: normalize_field_name(val.as_ref()), 
original_field: val.as_ref().to_string(), } } } pub(crate) fn normalize_field_name(field: &str) -> String { field .to_lowercase() .replace(|c: char| !c.is_alphanumeric(), "_") } ================================================ FILE: swiftide-integrations/src/lancedb/persist.rs ================================================ use std::sync::Arc; use anyhow::Context as _; use anyhow::Result; use arrow_array::Array; use arrow_array::FixedSizeListArray; use arrow_array::GenericByteArray; use arrow_array::RecordBatch; use arrow_array::RecordBatchIterator; use arrow_array::types::Float32Type; use arrow_array::types::UInt8Type; use arrow_array::types::Utf8Type; use async_trait::async_trait; use swiftide_core::Persist; use swiftide_core::indexing::IndexingStream; use swiftide_core::indexing::TextNode; use super::FieldConfig; use super::LanceDB; #[async_trait] impl Persist for LanceDB { type Input = String; type Output = String; #[tracing::instrument(skip_all)] async fn setup(&self) -> Result<()> { let conn = self.get_connection().await?; let schema = self.schema.clone(); if let Err(err) = conn.open_table(&self.table_name).execute().await { if matches!(err, lancedb::Error::TableNotFound { .. 
}) { conn.create_empty_table(&self.table_name, schema) .execute() .await .map(|_| ()) .map_err(anyhow::Error::from)?; } else { return Err(err.into()); } } Ok(()) } #[tracing::instrument(skip_all)] async fn store(&self, node: TextNode) -> Result<TextNode> { let mut nodes = vec![node; 1]; self.store_nodes(&nodes).await?; let node = nodes.swap_remove(0); Ok(node) } #[tracing::instrument(skip_all)] async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> { self.store_nodes(&nodes).await.map(|()| nodes).into() } fn batch_size(&self) -> Option<usize> { Some(self.batch_size) } } impl LanceDB { async fn store_nodes(&self, nodes: &[TextNode]) -> Result<()> { let schema = self.schema.clone(); let batches = self.extract_arrow_batches_from_nodes(nodes)?; let data = RecordBatchIterator::new( vec![ RecordBatch::try_new(schema.clone(), batches) .context("Could not create batches")?, ] .into_iter() .map(Ok), schema.clone(), ); let conn = self.get_connection().await?; let table = conn.open_table(&self.table_name).execute().await?; let mut merge_insert = table.merge_insert(&["id"]); merge_insert .when_matched_update_all(None) .when_not_matched_insert_all(); merge_insert.execute(Box::new(data)).await?; Ok(()) } fn extract_arrow_batches_from_nodes( &self, nodes: &[TextNode], ) -> core::result::Result<Vec<Arc<dyn Array>>, anyhow::Error> { let fields = self.fields.as_slice(); let mut batches: Vec<Arc<dyn Array>> = Vec::with_capacity(fields.len()); for field in fields { match field { FieldConfig::Vector(config) => { let mut row = Vec::with_capacity(nodes.len()); let vector_size = config .vector_size .or(self.vector_size) .context("Expected vector size to be set for field")?; for node in nodes { let data = node .vectors .as_ref() // TODO: verify compiler optimizes the double loops away .and_then(|v| v.get(&config.embedded_field)) .map(|v| v.iter().map(|f| Some(*f))); row.push(data); } batches.push(Arc::new(FixedSizeListArray::from_iter_primitive::< Float32Type, _, _, 
>(row, vector_size))); } FieldConfig::Metadata(config) => { let mut row = Vec::with_capacity(nodes.len()); for node in nodes { let data = node .metadata .get(&config.original_field) // TODO: Verify this gives the correct data .and_then(|v| v.as_str()); row.push(data); } batches.push(Arc::new(GenericByteArray::<Utf8Type>::from_iter(row))); } FieldConfig::Chunk => { let mut row = Vec::with_capacity(nodes.len()); for node in nodes { let data = Some(node.chunk.as_str()); row.push(data); } batches.push(Arc::new(GenericByteArray::<Utf8Type>::from_iter(row))); } FieldConfig::ID => { let mut row = Vec::with_capacity(nodes.len()); for node in nodes { let data = Some(node.id().as_bytes().map(Some)); row.push(data); } batches.push(Arc::new(FixedSizeListArray::from_iter_primitive::< UInt8Type, _, _, >(row, 16))); } } } Ok(batches) } } #[cfg(test)] mod test { use swiftide_core::{Persist as _, indexing::EmbeddedField}; use temp_dir::TempDir; use super::*; async fn setup() -> (TempDir, LanceDB) { let tempdir = TempDir::new().unwrap(); let lancedb = LanceDB::builder() .uri(tempdir.child("lancedb").to_str().unwrap()) .vector_size(384) .with_metadata("filter") .with_vector(EmbeddedField::Combined) .table_name("swiftide_test") .build() .unwrap(); lancedb.setup().await.unwrap(); (tempdir, lancedb) } #[tokio::test] async fn test_no_error_when_table_exists() { let (_guard, lancedb) = setup().await; lancedb .setup() .await .expect("Should not error if table exists"); } } ================================================ FILE: swiftide-integrations/src/lancedb/retrieve.rs ================================================ use anyhow::Result; use arrow_array::{RecordBatch, StringArray}; use async_trait::async_trait; use futures_util::TryStreamExt; use itertools::Itertools; use lancedb::query::{ExecutableQuery, QueryBase}; use swiftide_core::{ Retrieve, document::Document, indexing::Metadata, querying::{ Query, search_strategies::{CustomStrategy, SimilaritySingleEmbedding}, states, }, }; use 
super::{FieldConfig, LanceDB}; /// Implement the `Retrieve` trait for `SimilaritySingleEmbedding` search strategy. /// /// Can be used in the query pipeline to retrieve documents from `LanceDB`. /// /// Supports filters as strings. Refer to the `LanceDB` documentation for the format. #[async_trait] impl Retrieve<SimilaritySingleEmbedding<String>> for LanceDB { #[tracing::instrument] async fn retrieve( &self, search_strategy: &SimilaritySingleEmbedding<String>, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { let Some(embedding) = &query.embedding else { anyhow::bail!("No embedding for query") }; let table = self .get_connection() .await? .open_table(&self.table_name) .execute() .await?; let vector_fields = self .fields .iter() .filter(|field| matches!(field, FieldConfig::Vector(_))) .collect_vec(); if vector_fields.is_empty() || vector_fields.len() > 1 { anyhow::bail!("Zero or multiple vector fields configured in schema") } let column_name = vector_fields.first().map(|v| v.field_name()).unwrap(); let mut query_builder = table .query() .nearest_to(embedding.as_slice())? .column(&column_name) .limit(usize::try_from(search_strategy.top_k())?); if let Some(filter) = &search_strategy.filter() { query_builder = query_builder.only_if(filter); } let batches = query_builder .execute() .await? 
.try_collect::<Vec<_>>() .await?; let documents = Self::retrieve_from_record_batches(batches.as_slice()); Ok(query.retrieved_documents(documents)) } } #[async_trait] impl Retrieve<SimilaritySingleEmbedding> for LanceDB { async fn retrieve( &self, search_strategy: &SimilaritySingleEmbedding, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { Retrieve::<SimilaritySingleEmbedding<String>>::retrieve( self, &search_strategy.into_concrete_filter::<String>(), query, ) .await } } #[async_trait] impl<Q: ExecutableQuery + Send + Sync + 'static> Retrieve<CustomStrategy<Q>> for LanceDB { /// Implements vector similarity search for `LanceDB` using a custom query strategy. /// /// # Type Parameters /// * `VectorQuery` - `LanceDB`'s query type for vector similarity search async fn retrieve( &self, search_strategy: &CustomStrategy<Q>, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { // Build the custom query using both strategy and query state let query_builder = search_strategy.build_query(&query).await?; // Execute the query using the builder's built-in methods let batches = query_builder .execute() .await? .try_collect::<Vec<_>>() .await?; let documents = Self::retrieve_from_record_batches(batches.as_slice()); Ok(query.retrieved_documents(documents)) } } impl LanceDB { /// Retrieves documents from Arrow `RecordBatches` by processing each row and extracting content /// and metadata fields. /// /// The function expects a "chunk" field to contain the main document content, while all other /// string fields are treated as metadata. 
Non-string fields are currently skipped fn retrieve_from_record_batches(batches: &[RecordBatch]) -> Vec<Document> { let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum(); let mut documents = Vec::with_capacity(total_rows); let process_batch = |batch: &RecordBatch, documents: &mut Vec<Document>| { for row_idx in 0..batch.num_rows() { let schema = batch.schema(); let (content, metadata): (String, Option<Metadata>) = { let mut metadata = Metadata::default(); let mut content = String::new(); for (col_idx, field) in schema.as_ref().fields().iter().enumerate() { if let Some(array) = batch.column(col_idx).as_any().downcast_ref::<StringArray>() { let value = array.value(row_idx).to_string(); if field.name() == "chunk" { content = value; } else { metadata.insert(field.name().clone(), value); } } else { // Handle other array types as necessary // TODO: Can't we just downcast to serde::Value or fail? } } ( content, if metadata.is_empty() { None } else { Some(metadata) }, ) }; documents.push(Document::new(content, metadata)); } }; for batch in batches { process_batch(batch, &mut documents); } documents } } #[cfg(test)] mod test { use swiftide_core::{ Persist as _, indexing::{self, EmbeddedField}, }; use temp_dir::TempDir; use super::*; async fn setup() -> (TempDir, LanceDB) { let tempdir = TempDir::new().unwrap(); let lancedb = LanceDB::builder() .uri(tempdir.child("lancedb").to_str().unwrap()) .vector_size(384) .with_metadata("filter") .with_vector(EmbeddedField::Combined) .table_name("swiftide_test") .build() .unwrap(); lancedb.setup().await.unwrap(); (tempdir, lancedb) } #[tokio::test] async fn test_retrieve_multiple_docs_and_filter() { let (_guard, lancedb) = setup().await; let nodes = vec![ indexing::TextNode::new("test_query1").with_metadata(("filter", "true")), indexing::TextNode::new("test_query2").with_metadata(("filter", "true")), indexing::TextNode::new("test_query3").with_metadata(("filter", "false")), ] .into_iter() .map(|node| { 
node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]); node.to_owned() }) .collect(); lancedb .batch_store(nodes) .await .try_collect::<Vec<_>>() .await .unwrap(); let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vec![1.0; 384]); let search_strategy = SimilaritySingleEmbedding::from_filter("filter = \"true\"".to_string()); let result = lancedb .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 2); let search_strategy = SimilaritySingleEmbedding::from_filter("filter = \"banana\"".to_string()); let result = lancedb .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 0); let search_strategy = SimilaritySingleEmbedding::<()>::default(); let result = lancedb .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 3); } } ================================================ FILE: swiftide-integrations/src/lib.rs ================================================ // show feature flags in the generated documentation // https://doc.rust-lang.org/rustdoc/unstable-features.html#extensions-to-the-doc-attribute #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, doc(auto_cfg))] #![doc(html_logo_url = "https://github.com/bosun-ai/swiftide/raw/master/images/logo.png")] //! Integrations with various platforms and external services. 
#[cfg(feature = "anthropic")] pub mod anthropic; #[cfg(feature = "aws-bedrock")] pub mod aws_bedrock_v2; #[cfg(feature = "dashscope")] pub mod dashscope; #[cfg(feature = "duckdb")] pub mod duckdb; #[cfg(feature = "fastembed")] pub mod fastembed; #[cfg(feature = "fluvio")] pub mod fluvio; #[cfg(feature = "gemini")] pub mod gemini; #[cfg(feature = "groq")] pub mod groq; #[cfg(feature = "kafka")] pub mod kafka; #[cfg(feature = "lancedb")] pub mod lancedb; #[cfg(feature = "ollama")] pub mod ollama; #[cfg(feature = "open-router")] pub mod open_router; #[cfg(feature = "openai")] pub mod openai; #[cfg(feature = "parquet")] pub mod parquet; #[cfg(feature = "pgvector")] pub mod pgvector; #[cfg(feature = "qdrant")] pub mod qdrant; #[cfg(feature = "redb")] pub mod redb; #[cfg(feature = "redis")] pub mod redis; #[cfg(feature = "scraping")] pub mod scraping; #[cfg(feature = "tiktoken")] pub mod tiktoken; #[cfg(feature = "tree-sitter")] pub mod treesitter; ================================================ FILE: swiftide-integrations/src/ollama/config.rs ================================================ use derive_builder::Builder; use reqwest::header::HeaderMap; use secrecy::SecretString; use serde::Deserialize; const OLLAMA_API_BASE: &str = "http://localhost:11434/v1"; #[derive(Clone, Debug, Deserialize, Builder)] #[serde(default)] pub struct OllamaConfig { api_base: String, api_key: SecretString, } impl OllamaConfig { pub fn builder() -> OllamaConfigBuilder { OllamaConfigBuilder::default() } pub fn with_api_base(&mut self, api_base: &str) -> &mut Self { self.api_base = api_base.to_string(); self } } impl Default for OllamaConfig { fn default() -> Self { Self { api_base: OLLAMA_API_BASE.to_string(), api_key: String::new().into(), } } } impl async_openai::config::Config for OllamaConfig { fn headers(&self) -> HeaderMap { HeaderMap::new() } fn url(&self, path: &str) -> String { format!("{}{}", self.api_base, path) } fn api_base(&self) -> &str { &self.api_base } fn api_key(&self) -> 
&SecretString {
        &self.api_key
    }

    fn query(&self) -> Vec<(&str, &str)> {
        vec![]
    }
}
================================================ FILE: swiftide-integrations/src/ollama/mod.rs ================================================
//! This module provides integration with `Ollama`'s API, enabling the use of language models and
//! embeddings within the Swiftide project. It includes the `Ollama` struct for managing API clients
//! and default options for embedding and prompt models. The module is conditionally compiled based
//! on the "ollama" feature flag.
use config::OllamaConfig;

use crate::openai;

pub mod config;

/// The `Ollama` struct encapsulates an `Ollama` client and default options for embedding and prompt
/// models. It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// By default it connects to a local Ollama instance at `http://localhost:11434/v1` without
/// authentication (`OllamaConfig::default` uses an empty API key and `headers()` sends no auth
/// header); no `OLLAMA_API_KEY` environment variable is read. Note that either a prompt
/// model or embedding model always need to be set, either with
/// [`Ollama::with_default_prompt_model`] or [`Ollama::with_default_embed_model`] or via the
/// builder. You can find available models in the Ollama documentation.
///
/// Under the hood it uses [`async_openai`], with the Ollama openai mapping. This means
/// some features might not work as expected. See the Ollama documentation for details.
pub type Ollama = openai::GenericOpenAI<OllamaConfig>; pub type OllamaBuilder = openai::GenericOpenAIBuilder<OllamaConfig>; pub type OllamaBuilderError = openai::GenericOpenAIBuilderError; pub use openai::{Options, OptionsBuilder, OptionsBuilderError}; impl Ollama { /// Build a new `Ollama` instance pub fn builder() -> OllamaBuilder { OllamaBuilder::default() } } impl Default for Ollama { fn default() -> Self { Self::builder().build().unwrap() } } #[cfg(test)] mod test { use super::*; #[test] fn test_default_prompt_model() { let openai = Ollama::builder() .default_prompt_model("llama3.1") .build() .unwrap(); assert_eq!( openai.default_options.prompt_model, Some("llama3.1".to_string()) ); } #[test] fn test_default_embed_model() { let ollama = Ollama::builder() .default_embed_model("mxbai-embed-large") .build() .unwrap(); assert_eq!( ollama.default_options.embed_model, Some("mxbai-embed-large".to_string()) ); } #[test] fn test_default_models() { let ollama = Ollama::builder() .default_embed_model("mxbai-embed-large") .default_prompt_model("llama3.1") .build() .unwrap(); assert_eq!( ollama.default_options.embed_model, Some("mxbai-embed-large".to_string()) ); assert_eq!( ollama.default_options.prompt_model, Some("llama3.1".to_string()) ); } #[test] fn test_building_via_default_prompt_model() { let mut client = Ollama::default(); assert!(client.default_options.prompt_model.is_none()); client.with_default_prompt_model("llama3.1"); assert_eq!( client.default_options.prompt_model, Some("llama3.1".to_string()) ); } #[test] fn test_building_via_default_embed_model() { let mut client = Ollama::default(); assert!(client.default_options.embed_model.is_none()); client.with_default_embed_model("mxbai-embed-large"); assert_eq!( client.default_options.embed_model, Some("mxbai-embed-large".to_string()) ); } #[test] fn test_building_via_default_models() { let mut client = Ollama::default(); assert!(client.default_options.embed_model.is_none()); 
client.with_default_prompt_model("llama3.1");
        client.with_default_embed_model("mxbai-embed-large");

        assert_eq!(
            client.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
        assert_eq!(
            client.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
    }
}
================================================ FILE: swiftide-integrations/src/open_router/config.rs ================================================
use derive_builder::Builder;
use reqwest::header::{AUTHORIZATION, HeaderMap};
use secrecy::{ExposeSecret as _, SecretString};
use serde::Deserialize;

const OPENROUTER_API_BASE: &str = "https://openrouter.ai/api/v1";

/// Configuration for the `OpenRouter` integration.
///
/// `site_url` and `site_name` are optional attribution values sent as the
/// `HTTP-Referer` and `X-Title` headers respectively (leaderboard).
#[derive(Clone, Debug, Deserialize, Builder)]
#[serde(default)]
#[builder(setter(into, strip_option))]
pub struct OpenRouterConfig {
    #[builder(default = OPENROUTER_API_BASE.to_string())]
    api_base: String,
    api_key: SecretString,
    /// Sets the HTTP-Referer header (leaderboard)
    site_url: Option<String>,
    /// Sets the name (leaderboard)
    site_name: Option<String>,
}

impl OpenRouterConfig {
    pub fn builder() -> OpenRouterConfigBuilder {
        OpenRouterConfigBuilder::default()
    }

    /// Overrides the API base url.
    pub fn with_api_base(&mut self, api_base: &str) -> &mut Self {
        self.api_base = api_base.to_string();
        self
    }

    /// Overrides the API key.
    pub fn with_api_key(&mut self, api_key: impl Into<SecretString>) -> &mut Self {
        self.api_key = api_key.into();
        self
    }

    /// Sets the site url sent as the `HTTP-Referer` header (leaderboard attribution).
    pub fn with_site_url(&mut self, site_url: &str) -> &mut Self {
        self.site_url = Some(site_url.to_string());
        self
    }

    /// Sets the site name sent as the `X-Title` header (leaderboard attribution).
    pub fn with_site_name(&mut self, site_name: &str) -> &mut Self {
        self.site_name = Some(site_name.to_string());
        self
    }
}

impl Default for OpenRouterConfig {
    fn default() -> Self {
        Self {
            api_base: OPENROUTER_API_BASE.to_string(),
            api_key: std::env::var("OPENROUTER_API_KEY")
                .unwrap_or_else(|_| String::new())
                .into(),
            site_url: None,
            site_name: None,
        }
    }
}

impl async_openai::config::Config for OpenRouterConfig {
    fn headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        let api_key = self.api_key.expose_secret();
        // NOTE: panics (by design) when no key is configured; `headers()` cannot
        // return a `Result` under the `async_openai::config::Config` contract.
        assert!(!api_key.is_empty(), "API key for OpenRouter is required");

        headers.insert(
            AUTHORIZATION,
            format!("Bearer {}", self.api_key.expose_secret())
                .as_str()
                .parse()
                .unwrap(),
        );

        if let Ok(site_url) = self
            .site_url
            .as_deref()
            .unwrap_or("https://github.com/bosun-ai/swiftide")
            .parse()
        {
            headers.insert("HTTP-Referer", site_url);
        }

        // BUG FIX: this previously read `self.site_url`, so a configured
        // `site_name` was never sent and X-Title echoed the referer URL.
        if let Ok(site_name) = self.site_name.as_deref().unwrap_or("Swiftide").parse() {
            headers.insert("X-Title", site_name);
        }

        headers
    }

    fn url(&self, path: &str) -> String {
        format!("{}{}", self.api_base, path)
    }

    fn api_base(&self) -> &str {
        &self.api_base
    }

    fn api_key(&self) -> &SecretString {
        &self.api_key
    }

    fn query(&self) -> Vec<(&str, &str)> {
        vec![]
    }
}
================================================ FILE: swiftide-integrations/src/open_router/mod.rs ================================================
//! This module provides integration with `OpenRouter`'s API, enabling the use of language models
//! and embeddings within the Swiftide project. It includes the `OpenRouter` struct for managing API
//! clients and default options for embedding and prompt models. The module is conditionally
//! compiled based on the "openrouter" feature flag.
use config::OpenRouterConfig;

use crate::openai;

pub mod config;

/// The `OpenRouter` struct encapsulates an `OpenRouter` client and default options for embedding
/// and prompt models. It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// By default it will look for a `OPENROUTER_API_KEY` environment variable. Note that either a
/// prompt model or embedding model always need to be set, either with
/// [`OpenRouter::with_default_prompt_model`] or [`OpenRouter::with_default_embed_model`] or via the
/// builder. You can find available models in the `OpenRouter` documentation.
///
/// Under the hood it uses [`async_openai`], with the `OpenRouter` openai compatible api. This means
/// some features might not work as expected. See the `OpenRouter` documentation for details.
pub type OpenRouter = openai::GenericOpenAI<OpenRouterConfig>; pub type OpenRouterBuilder = openai::GenericOpenAIBuilder<OpenRouterConfig>; pub type OpenRouterBuilderError = openai::GenericOpenAIBuilderError; pub use openai::{Options, OptionsBuilder, OptionsBuilderError}; impl OpenRouter { /// Creates a new `OpenRouterBuilder` for constructing `OpenRouter` instances. pub fn builder() -> OpenRouterBuilder { OpenRouterBuilder::default() } } impl Default for OpenRouter { fn default() -> Self { Self::builder().build().unwrap() } } #[cfg(test)] mod test { use super::*; #[test] fn test_default_prompt_model() { let openai = OpenRouter::builder() .default_prompt_model("llama3.1") .build() .unwrap(); assert_eq!( openai.default_options.prompt_model, Some("llama3.1".to_string()) ); } #[test] fn test_default_models() { let openrouter = OpenRouter::builder() .default_prompt_model("llama3.1") .build() .unwrap(); assert_eq!( openrouter.default_options.prompt_model, Some("llama3.1".to_string()) ); } #[test] fn test_building_via_default_prompt_model() { let mut client = OpenRouter::default(); assert!(client.default_options.prompt_model.is_none()); client.with_default_prompt_model("llama3.1"); assert_eq!( client.default_options.prompt_model, Some("llama3.1".to_string()) ); } } ================================================ FILE: swiftide-integrations/src/openai/chat_completion.rs ================================================ use anyhow::{Context as _, Result}; use async_openai::types::chat::{ ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart, ChatCompletionStreamOptions, 
ChatCompletionToolChoiceOption, ChatCompletionTools, FunctionCall, FunctionObject, ImageUrl, InputAudio, InputAudioFormat, ToolChoiceOptions, }; use async_trait::async_trait; use base64::Engine as _; use futures_util::StreamExt as _; use futures_util::stream; use itertools::Itertools; use serde::Serialize; use swiftide_core::ChatCompletionStream; use swiftide_core::chat_completion::Usage; use swiftide_core::chat_completion::{ ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ChatMessageContentPart, ChatMessageContentSource, ToolCall, ToolSpec, errors::LanguageModelError, }; #[cfg(feature = "metrics")] use swiftide_core::metrics::emit_usage; use super::GenericOpenAI; use super::openai_error_to_language_model_error; use super::responses_api::{ build_responses_request_from_chat, response_to_chat_completion, responses_stream_adapter, }; use super::tool_schema::OpenAiToolSchema; use tracing_futures::Instrument; #[async_trait] impl< C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone + 'static, > ChatCompletion for GenericOpenAI<C> { #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))] #[cfg_attr( feature = "langfuse", tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION")) )] async fn complete( &self, request: &ChatCompletionRequest<'_>, ) -> Result<ChatCompletionResponse, LanguageModelError> { if self.is_responses_api_enabled() { return self.complete_via_responses_api(request).await; } let model = self .default_options .prompt_model .as_ref() .context("Model not set")?; let messages = request .messages() .iter() .filter_map(|message| message_to_openai(message).transpose()) .collect::<Result<Vec<_>>>()?; // Build the request to be sent to the OpenAI API. 
let mut openai_request = self .chat_completion_request_defaults() .model(model) .messages(messages) .to_owned(); if !request.tools_spec().is_empty() { openai_request .tools( request .tools_spec() .iter() .map(tools_to_openai) .collect::<Result<Vec<_>>>()?, ) .tool_choice(ChatCompletionToolChoiceOption::Mode( ToolChoiceOptions::Auto, )); if let Some(par) = self.default_options.parallel_tool_calls { openai_request.parallel_tool_calls(par); } } let openai_request = openai_request .build() .map_err(openai_error_to_language_model_error)?; tracing::trace!(model, request = ?request, "Sending request to OpenAI"); let tracking_request = openai_request.clone(); let response = self .client .chat() .create(openai_request) .await .map_err(openai_error_to_language_model_error)?; tracing::trace!(?response, "[ChatCompletion] Full response from OpenAI"); // Make sure the debug log is a concise one line let mut builder = ChatCompletionResponse::builder() .maybe_message( response .choices .first() .and_then(|choice| choice.message.content.clone()), ) .maybe_tool_calls( response .choices .first() .and_then(|choice| choice.message.tool_calls.as_ref()) .map(|tool_calls| { tool_calls .iter() .filter_map(|tool_call| match tool_call { ChatCompletionMessageToolCalls::Function(call) => Some( ToolCall::builder() .id(call.id.clone()) .args(call.function.arguments.clone()) .name(call.function.name.clone()) .build() .expect("infallible"), ), ChatCompletionMessageToolCalls::Custom(_) => None, }) .collect_vec() }), ) .to_owned(); if let Some(usage) = &response.usage { builder.usage(Usage::from(usage)); } let our_response = builder.build().map_err(LanguageModelError::from)?; self.track_completion( model, our_response.usage.as_ref(), Some(&tracking_request), Some(&our_response), ); Ok(our_response) } #[tracing::instrument(skip_all)] async fn complete_stream(&self, request: &ChatCompletionRequest<'_>) -> ChatCompletionStream { if self.is_responses_api_enabled() { return 
self.complete_stream_via_responses_api(request).await; } let Some(model_name) = self.default_options.prompt_model.clone() else { return LanguageModelError::permanent("Model not set").into(); }; #[cfg(not(any(feature = "metrics", feature = "langfuse")))] let _ = &model_name; let messages = match request .messages() .iter() .filter_map(|message| message_to_openai(message).transpose()) .collect::<Result<Vec<_>>>() { Ok(messages) => messages, Err(e) => return LanguageModelError::from(e).into(), }; // Build the request to be sent to the OpenAI API. let mut openai_request = self .chat_completion_request_defaults() .model(&model_name) .messages(messages) .stream(true) .stream_options(ChatCompletionStreamOptions { include_usage: Some(true), include_obfuscation: None, }) .to_owned(); if !request.tools_spec().is_empty() { openai_request .tools( match request .tools_spec() .iter() .map(tools_to_openai) .collect::<Result<Vec<_>>>() { Ok(tools) => tools, Err(e) => { return LanguageModelError::from(e).into(); } }, ) .tool_choice(ChatCompletionToolChoiceOption::Mode( ToolChoiceOptions::Auto, )); if let Some(par) = self.default_options.parallel_tool_calls { openai_request.parallel_tool_calls(par); } } let openai_request = match openai_request.build() { Ok(request) => request, Err(e) => { return openai_error_to_language_model_error(e).into(); } }; tracing::trace!(model = %model_name, request = ?request, "Sending request to OpenAI"); let response_stream = match self .client .chat() .create_stream(openai_request.clone()) .await { Ok(response) => response, Err(e) => return openai_error_to_language_model_error(e).into(), }; let stream_full = self.stream_full; let model_name_for_track = model_name.clone(); let self_for_stream = self.clone(); let tracking_request = openai_request; let span = if cfg!(feature = "langfuse") { tracing::info_span!("stream", langfuse.type = "GENERATION") } else { tracing::info_span!("stream") }; let stream = stream::unfold( ( response_stream, 
ChatCompletionResponse::default(), tracking_request.clone(), false, // finished ), move |(mut response_stream, mut state, tracking_request, finished)| { let stream_full = stream_full; let self_for_stream = self_for_stream.clone(); let model_name_for_track = model_name_for_track.clone(); async move { if finished { return None; } match response_stream.next().await { Some(Ok(chunk)) => { let delta_message = chunk .choices .first() .and_then(|d| d.delta.content.as_deref()); let delta_tool_calls = chunk .choices .first() .and_then(|d| d.delta.tool_calls.as_deref()); let usage = chunk.usage.as_ref(); state.append_message_delta(delta_message); if let Some(delta_tool_calls) = delta_tool_calls { for tc in delta_tool_calls { state.append_tool_call_delta( tc.index as usize, tc.id.as_deref(), tc.function.as_ref().and_then(|f| f.name.as_deref()), tc.function.as_ref().and_then(|f| f.arguments.as_deref()), ); } } if let Some(usage) = usage { let usage = Usage::from(usage); state.append_usage_delta( usage.prompt_tokens, usage.completion_tokens, usage.total_tokens, ); } let snapshot = if stream_full { state.clone() } else { ChatCompletionResponse { id: state.id, message: None, tool_calls: None, usage: None, reasoning: None, delta: state.delta.clone(), } }; Some(( Ok(snapshot), (response_stream, state, tracking_request, false), )) } Some(Err(err)) => Some(( Err(openai_error_to_language_model_error(err)), (response_stream, state, tracking_request, true), )), None => { // Final emission; track completion with the full state. 
self_for_stream.track_completion( &model_name_for_track, state.usage.as_ref(), Some(&tracking_request), Some(&state), ); let final_snapshot = state.clone(); Some(( Ok(final_snapshot), (response_stream, state, tracking_request, true), )) } } } }, ); Box::pin(tracing_futures::Instrument::instrument(stream, span)) } } impl< C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone + 'static, > GenericOpenAI<C> { async fn complete_via_responses_api( &self, request: &ChatCompletionRequest<'_>, ) -> Result<ChatCompletionResponse, LanguageModelError> { let model = self .default_options .prompt_model .as_ref() .context("Model not set")?; let create_request = build_responses_request_from_chat(self, request)?; let tracking_request = create_request.clone(); let response = self .client .responses() .create(create_request) .await .map_err(openai_error_to_language_model_error)?; let completion = response_to_chat_completion(&response)?; self.track_completion( model, completion.usage.as_ref(), Some(&tracking_request), Some(&completion), ); Ok(completion) } #[allow(clippy::too_many_lines)] async fn complete_stream_via_responses_api( &self, request: &ChatCompletionRequest<'_>, ) -> ChatCompletionStream { #[allow(unused_variables)] let Some(model_name) = self.default_options.prompt_model.clone() else { return LanguageModelError::permanent("Model not set").into(); }; let mut create_request = match build_responses_request_from_chat(self, request) { Ok(req) => req, Err(err) => return err.into(), }; create_request.stream = Some(true); let stream = match self .client .responses() .create_stream(create_request.clone()) .await { Ok(stream) => stream, Err(err) => return openai_error_to_language_model_error(err).into(), }; let stream_full = self.stream_full; let span = if cfg!(feature = "langfuse") { tracing::info_span!("responses_stream", langfuse.type = "GENERATION") } else { tracing::info_span!("responses_stream") }; let mapped_stream = 
responses_stream_adapter(stream, stream_full); let this = self.clone(); let tracked_request = create_request; let mapped_stream = mapped_stream.map(move |result| match result { Ok(item) => { if item.finished { this.track_completion( &model_name, item.response.usage.as_ref(), Some(&tracked_request), Some(&item.response), ); } Ok(item.response) } Err(err) => Err(err), }); Box::pin(Instrument::instrument(mapped_stream, span)) } #[allow(unused_variables)] pub(crate) fn track_completion<R, S>( &self, model: &str, usage: Option<&Usage>, request: Option<&R>, response: Option<&S>, ) where R: Serialize + ?Sized, S: Serialize + ?Sized, { if let Some(usage) = usage { let cb_usage = usage.clone(); if let Some(callback) = &self.on_usage { let callback = callback.clone(); tokio::spawn(async move { if let Err(err) = callback(&cb_usage).await { tracing::error!("Error in on_usage callback: {err}"); } }); } #[cfg(feature = "metrics")] emit_usage( model, usage.prompt_tokens.into(), usage.completion_tokens.into(), usage.total_tokens.into(), self.metric_metadata.as_ref(), ); } #[cfg(feature = "langfuse")] tracing::debug!( langfuse.model = model, langfuse.input = request.and_then(langfuse_json_redacted).unwrap_or_default(), langfuse.output = response.and_then(langfuse_json).unwrap_or_default(), langfuse.usage = usage.and_then(langfuse_json).unwrap_or_default(), ); } } #[cfg(feature = "langfuse")] pub(crate) fn langfuse_json<T: Serialize + ?Sized>(value: &T) -> Option<String> { serde_json::to_string_pretty(value).ok() } #[cfg(feature = "langfuse")] pub(crate) fn langfuse_json_redacted<T: Serialize + ?Sized>(value: &T) -> Option<String> { let mut value = serde_json::to_value(value).ok()?; redact_image_urls(&mut value); serde_json::to_string_pretty(&value).ok() } #[cfg(feature = "langfuse")] fn redact_image_urls(value: &mut serde_json::Value) { match value { serde_json::Value::Object(map) => { if let Some(image_url) = map.get_mut("image_url") && let serde_json::Value::Object(image_obj) = 
image_url && let Some(serde_json::Value::String(url)) = image_obj.get_mut("url") && let Some(truncated) = truncate_data_url(url) { *url = truncated; } for val in map.values_mut() { redact_image_urls(val); } } serde_json::Value::Array(arr) => { for val in arr { redact_image_urls(val); } } _ => {} } } #[cfg(feature = "langfuse")] fn truncate_data_url(url: &str) -> Option<String> { const MAX_DATA_PREVIEW: usize = 32; if !url.starts_with("data:") { return None; } let (prefix, data) = url.split_once(',')?; if data.len() <= MAX_DATA_PREVIEW { return None; } let preview = &data[..MAX_DATA_PREVIEW]; let truncated = data.len() - MAX_DATA_PREVIEW; Some(format!( "{prefix},{preview}...[truncated {truncated} chars]" )) } #[cfg(not(feature = "langfuse"))] #[allow(dead_code)] pub(crate) fn langfuse_json<T>(_value: &T) -> Option<String> { None } fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTools> { let parameters = OpenAiToolSchema::try_from(spec) .context("tool schema must be OpenAI compatible")? .into_value(); let function = FunctionObject { name: spec.name.clone(), description: Some(spec.description.clone()), parameters: Some(parameters), strict: Some(true), }; Ok(ChatCompletionTools::Function( async_openai::types::chat::ChatCompletionTool { function }, )) } fn message_to_openai( message: &ChatMessage, ) -> Result<Option<async_openai::types::chat::ChatCompletionRequestMessage>> { let openai_message = match message { ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default() .content(msg.as_ref()) .build()? .into(), ChatMessage::UserWithParts(parts) => ChatCompletionRequestUserMessageArgs::default() .content(user_parts_to_openai(parts)?) .build()? .into(), ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default() .content(msg.as_ref()) .build()? .into(), ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default() .content(msg.as_ref()) .build()? 
.into(), ChatMessage::ToolOutput(tool_call, tool_output) => { let Some(content) = tool_output.content() else { return Ok(Some( ChatCompletionRequestToolMessageArgs::default() .tool_call_id(tool_call.id()) .build()? .into(), )); }; ChatCompletionRequestToolMessageArgs::default() .content(content) .tool_call_id(tool_call.id()) .build()? .into() } ChatMessage::Assistant(content, tool_calls) => { let mut builder = ChatCompletionRequestAssistantMessageArgs::default(); let has_tool_calls = tool_calls.as_ref().is_some_and(|calls| !calls.is_empty()); if let Some(content) = content.as_deref() { builder.content(content); } if let Some(tool_calls) = tool_calls.as_ref() { let calls = tool_calls .iter() .map(|tool_call| { ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall { id: tool_call.id().to_string(), function: FunctionCall { name: tool_call.name().to_string(), arguments: tool_call.args().unwrap_or_default().to_string(), }, }) }) .collect::<Vec<_>>(); builder.tool_calls(calls); } if content.is_none() && !has_tool_calls { return Ok(None); } builder.build()?.into() } ChatMessage::Reasoning(_) => return Ok(None), }; Ok(Some(openai_message)) } fn user_parts_to_openai( parts: &[ChatMessageContentPart], ) -> Result<ChatCompletionRequestUserMessageContent> { let mapped = parts .iter() .map(part_to_openai_user_content_part) .collect::<Result<Vec<_>>>()?; Ok(ChatCompletionRequestUserMessageContent::Array(mapped)) } fn part_to_openai_user_content_part( part: &ChatMessageContentPart, ) -> Result<ChatCompletionRequestUserMessageContentPart> { Ok(match part { ChatMessageContentPart::Text { text } => ChatCompletionRequestUserMessageContentPart::from( ChatCompletionRequestMessageContentPartText::from(text.as_ref()), ), ChatMessageContentPart::Image { source, .. 
} => { let image_url = ImageUrl { url: source_to_openai_url(source)?, detail: None, }; ChatCompletionRequestUserMessageContentPart::from( ChatCompletionRequestMessageContentPartImage { image_url }, ) } ChatMessageContentPart::Audio { source, format } => { let ChatMessageContentSource::Bytes { data, .. } = source else { anyhow::bail!("OpenAI chat input_audio only supports bytes sources"); }; let format = match format.as_deref() { Some("wav") => InputAudioFormat::Wav, Some("mp3") | None => InputAudioFormat::Mp3, Some(other) => anyhow::bail!("Unsupported OpenAI chat input_audio format: {other}"), }; let input_audio = InputAudio { data: base64::engine::general_purpose::STANDARD.encode(data), format, }; ChatCompletionRequestUserMessageContentPart::from( ChatCompletionRequestMessageContentPartAudio { input_audio }, ) } ChatMessageContentPart::Document { .. } => { anyhow::bail!("OpenAI chat file parts are not supported by async-openai yet") } ChatMessageContentPart::Video { .. } => { anyhow::bail!("OpenAI chat completion does not support video parts") } }) } fn source_to_openai_url(source: &ChatMessageContentSource) -> Result<String> { match source { ChatMessageContentSource::Url { url } => Ok(url.clone()), ChatMessageContentSource::FileId { .. } => { anyhow::bail!("OpenAI chat image_url does not accept file_id sources") } ChatMessageContentSource::S3 { .. 
} => {
            anyhow::bail!("OpenAI chat image_url does not accept s3 sources")
        }
        // Inline bytes are embedded as a base64 `data:` url; fall back to a
        // generic media type when none was supplied.
        ChatMessageContentSource::Bytes { data, media_type } => {
            let media_type = media_type.as_deref().unwrap_or("application/octet-stream");
            let encoded = base64::engine::general_purpose::STANDARD.encode(data);
            Ok(format!("data:{media_type};base64,{encoded}"))
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::openai::{OpenAI, Options};

    use super::*;
    use futures_util::StreamExt;
    use serde_json::json;
    use std::sync::Arc;
    use swiftide_core::chat_completion::{ToolCallBuilder, ToolOutput, UsageBuilder};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Minimal tool-argument schema used to exercise tool schema conversion.
    #[allow(dead_code)]
    #[derive(schemars::JsonSchema)]
    struct WeatherArgs {
        _city: String,
    }

    // Wrapper schema whose single field is a nested object, used to check how
    // conversion treats nested typed request payloads.
    #[allow(dead_code)]
    #[derive(schemars::JsonSchema)]
    #[serde(deny_unknown_fields)]
    struct NestedCommentArgs {
        request: NestedCommentRequest,
    }

    // Nested request whose fields are all optional at the serde level; the test
    // below asserts the converted schema still lists them under `required`.
    #[allow(dead_code)]
    #[derive(schemars::JsonSchema)]
    #[serde(deny_unknown_fields)]
    struct NestedCommentRequest {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        body: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        text: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        page_id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        block_id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        discussion_id: Option<String>,
    }

    // Chat Completions tool parameters must carry `additionalProperties: false`;
    // verify the converter injects it.
    #[test]
    fn test_tools_to_openai_sets_additional_properties_false() {
        let spec = ToolSpec::builder()
            .name("get_weather")
            .description("Retrieve weather data")
            .parameters_schema(schemars::schema_for!(WeatherArgs))
            .build()
            .unwrap();

        let tool = tools_to_openai(&spec).expect("tool conversion succeeds");
        let function = match tool {
            ChatCompletionTools::Function(ref tool) => &tool.function,
            ChatCompletionTools::Custom(_) => panic!("expected function tool"),
        };

        let additional_properties = function
            .parameters
            .as_ref()
            .and_then(serde_json::Value::as_object)
.and_then(|obj| obj.get("additionalProperties")) .cloned(); assert_eq!( additional_properties, Some(serde_json::Value::Bool(false)), "Chat Completions require additionalProperties=false for tool parameters, got {}", serde_json::to_string_pretty(&function.parameters).unwrap() ); } #[test] fn test_tools_to_openai_sets_nested_required_for_typed_request_objects() { let spec = ToolSpec::builder() .name("notion_create_comment") .description("Create a comment") .parameters_schema(schemars::schema_for!(NestedCommentArgs)) .build() .unwrap(); let tool = tools_to_openai(&spec).expect("tool conversion succeeds"); let function = match tool { ChatCompletionTools::Function(ref tool) => &tool.function, ChatCompletionTools::Custom(_) => panic!("expected function tool"), }; let nested_required = function.parameters.as_ref().and_then(|schema| { let request_schema = schema .get("properties") .and_then(|value| value.get("request")) .and_then(serde_json::Value::as_object)?; let referenced_required = request_schema .get("$ref") .and_then(serde_json::Value::as_str) .and_then(|reference| reference.strip_prefix("#/$defs/")) .and_then(|definition_name| { schema .get("$defs") .and_then(|value| value.get(definition_name)) }) .and_then(|value| value.get("required")) .and_then(serde_json::Value::as_array); referenced_required.or_else(|| { request_schema .get("required") .and_then(serde_json::Value::as_array) }) }); let nested_required = nested_required.expect("nested request should have required"); let names: std::collections::HashSet<_> = nested_required .iter() .filter_map(serde_json::Value::as_str) .collect(); assert!(names.contains("body")); assert!(names.contains("text")); assert!(names.contains("page_id")); assert!(names.contains("block_id")); assert!(names.contains("discussion_id")); } #[test] fn test_message_to_openai_with_image_parts() { let message = ChatMessage::new_user_with_parts(vec![ ChatMessageContentPart::text("Describe this image."), 
ChatMessageContentPart::image("https://example.com/image.png"), ]); let openai_message = message_to_openai(&message) .expect("message conversion succeeds") .expect("message present"); let value = serde_json::to_value(openai_message).expect("serialize message"); let content = value .get("content") .and_then(serde_json::Value::as_array) .expect("content array"); assert_eq!(content[0]["type"], "text"); assert_eq!(content[0]["text"], "Describe this image."); assert_eq!(content[1]["type"], "image_url"); assert_eq!( content[1]["image_url"]["url"], "https://example.com/image.png" ); assert!(content[1]["image_url"]["detail"].is_null()); } #[test] fn test_message_to_openai_with_image_bytes_source() { let message = ChatMessage::new_user_with_parts(vec![ ChatMessageContentPart::text("Describe this image."), ChatMessageContentPart::Image { source: ChatMessageContentSource::bytes( vec![0_u8, 1_u8, 2_u8], Some("image/png".to_string()), ), format: None, }, ]); let openai_message = message_to_openai(&message) .expect("message conversion succeeds") .expect("message present"); let value = serde_json::to_value(openai_message).expect("serialize message"); let content = value .get("content") .and_then(serde_json::Value::as_array) .expect("content array"); let image_url = content[1]["image_url"]["url"] .as_str() .expect("image_url must be string"); assert!(image_url.starts_with("data:image/png;base64,")); } #[test_log::test(tokio::test)] async fn test_complete() { let mock_server = MockServer::start().await; // Mock OpenAI API response let response_body = json!({ "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT", "object": "chat.completion", "created": 123, "model": "gpt-4o", "choices": [ { "index": 0, "message": { "role": "assistant", "content": "Hello, world!", "refusal": null, "annotations": [] }, "logprobs": null, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 19, "completion_tokens": 10, "total_tokens": 29, "prompt_tokens_details": { "cached_tokens": 0, "audio_tokens": 0 }, 
"completion_tokens_details": { "reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "service_tier": "default" }); Mock::given(method("POST")) .and(path("/chat/completions")) .respond_with(ResponseTemplate::new(200).set_body_json(response_body)) .mount(&mock_server) .await; // Create a GenericOpenAI instance with the mock server URL let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri()); let async_openai = async_openai::Client::with_config(config); let openai = OpenAI::builder() .client(async_openai) .default_prompt_model("gpt-4o") .build() .expect("Can create OpenAI client."); // Prepare a test request let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("Hi".into())]) .build() .unwrap(); // Call the `complete` method let response = openai.complete(&request).await.unwrap(); // Assert the response assert_eq!(response.message(), Some("Hello, world!")); // Usage let usage = response.usage.unwrap(); assert_eq!(usage.prompt_tokens, 19); assert_eq!(usage.completion_tokens, 10); assert_eq!(usage.total_tokens, 29); let details = usage.details.as_ref().expect("usage details"); assert_eq!( details .prompt_tokens_details .as_ref() .and_then(|d| d.cached_tokens), Some(0) ); assert_eq!( details .completion_tokens_details .as_ref() .and_then(|d| d.reasoning_tokens), Some(0) ); let normalized = usage.normalized(); let normalized_details = normalized.details.expect("normalized details"); assert_eq!(normalized_details.input.cached_tokens, Some(0)); assert_eq!(normalized_details.output.reasoning_tokens, Some(0)); } #[test_log::test(tokio::test)] #[allow(clippy::items_after_statements)] async fn test_complete_responses_api() { use serde_json::{Value, json}; use wiremock::{Request, Respond}; let mock_server = MockServer::start().await; let response_body = json!({ "created_at": 123, "id": "resp_123", "model": "gpt-4.1-mini", "object": "response", "status": "completed", 
"output": [ { "type": "message", "id": "msg_1", "role": "assistant", "status": "completed", "content": [ {"type": "output_text", "text": "Hello via responses", "annotations": []} ] } ], "usage": { "input_tokens": 5, "input_tokens_details": {"cached_tokens": 0}, "output_tokens": 3, "output_tokens_details": {"reasoning_tokens": 0}, "total_tokens": 8 } }); struct ValidateResponsesRequest { expected_model: &'static str, response: Value, } impl Respond for ValidateResponsesRequest { fn respond(&self, request: &Request) -> ResponseTemplate { let body: Value = serde_json::from_slice(&request.body).unwrap(); assert_eq!(body["model"], self.expected_model); let input = body["input"].as_array().expect("input array"); assert_eq!(input.len(), 1); assert_eq!(input[0]["role"], "user"); assert_eq!(input[0]["content"], "Hello via prompt"); let _: async_openai::types::responses::Response = serde_json::from_value(self.response.clone()).unwrap(); ResponseTemplate::new(200).set_body_json(self.response.clone()) } } Mock::given(method("POST")) .and(path("/responses")) .respond_with(ValidateResponsesRequest { expected_model: "gpt-4.1-mini", response: response_body, }) .mount(&mock_server) .await; let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri()); let async_openai = async_openai::Client::with_config(config); let openai = OpenAI::builder() .client(async_openai) .default_prompt_model("gpt-4.1-mini") .use_responses_api(true) .build() .expect("Can create OpenAI client."); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("Hello via prompt".into())]) .build() .unwrap(); let response = openai.complete(&request).await.unwrap(); assert_eq!(response.message(), Some("Hello via responses")); let usage = response.usage.expect("usage present"); assert_eq!(usage.prompt_tokens, 5); assert_eq!(usage.completion_tokens, 3); assert_eq!(usage.total_tokens, 8); let details = usage.details.as_ref().expect("usage details"); assert_eq!( details 
.input_tokens_details .as_ref() .and_then(|d| d.cached_tokens), Some(0) ); assert_eq!( details .output_tokens_details .as_ref() .and_then(|d| d.reasoning_tokens), Some(0) ); let normalized = usage.normalized(); let normalized_details = normalized.details.expect("normalized details"); assert_eq!(normalized_details.input.cached_tokens, Some(0)); assert_eq!(normalized_details.output.reasoning_tokens, Some(0)); } #[test_log::test(tokio::test)] #[allow(clippy::items_after_statements)] async fn test_complete_with_all_default_settings() { use serde_json::Value; use wiremock::{Request, Respond, ResponseTemplate}; let mock_server = wiremock::MockServer::start().await; // Custom matcher to validate all settings in the incoming request struct ValidateAllSettings; impl Respond for ValidateAllSettings { fn respond(&self, request: &Request) -> ResponseTemplate { let v: Value = serde_json::from_slice(&request.body).unwrap(); // Validate required fields assert_eq!(v["model"], "gpt-4-turbo"); let arr = v["messages"].as_array().unwrap(); assert_eq!(arr.len(), 1); assert_eq!(arr[0]["content"], "Test"); assert_eq!(v["parallel_tool_calls"], true); assert_eq!(v["max_completion_tokens"], 77); assert!((v["temperature"].as_f64().unwrap() - 0.42).abs() < 1e-5); assert_eq!(v["reasoning_effort"], serde_json::Value::Null); assert_eq!(v["seed"], 42); assert!((v["presence_penalty"].as_f64().unwrap() - 1.1).abs() < 1e-5); // Metadata as JSON object and user string assert_eq!(v["metadata"], serde_json::json!({"key": "value"})); assert_eq!(v["user"], "test-user"); ResponseTemplate::new(200).set_body_json(serde_json::json!({ "id": "chatcmpl-xxx", "object": "chat.completion", "created": 123, "model": "gpt-4-turbo", "choices": [{ "index": 0, "message": { "role": "assistant", "content": "All settings validated", "refusal": null, "annotations": [] }, "logprobs": null, "finish_reason": "stop" }], "usage": { "prompt_tokens": 19, "completion_tokens": 10, "total_tokens": 29, "prompt_tokens_details": 
{"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0} }, "service_tier": "default" })) } } wiremock::Mock::given(wiremock::matchers::method("POST")) .and(wiremock::matchers::path("/chat/completions")) .respond_with(ValidateAllSettings) .mount(&mock_server) .await; let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri()); let async_openai = async_openai::Client::with_config(config); let openai = crate::openai::OpenAI::builder() .client(async_openai) .default_prompt_model("gpt-4-turbo") .default_embed_model("not-used") .parallel_tool_calls(Some(true)) .default_options( Options::builder() .max_completion_tokens(77) .temperature(0.42) .reasoning_effort(async_openai::types::responses::ReasoningEffort::Low) .seed(42) .presence_penalty(1.1) .metadata(serde_json::json!({"key": "value"})) .user("test-user"), ) .build() .expect("Can create OpenAI client."); let request = swiftide_core::chat_completion::ChatCompletionRequest::builder() .messages(vec![swiftide_core::chat_completion::ChatMessage::User( "Test".into(), )]) .build() .unwrap(); let response = openai.complete(&request).await.unwrap(); assert_eq!(response.message(), Some("All settings validated")); } #[test_log::test(tokio::test)] async fn test_complete_with_tools_sets_auto_choice_and_parallel_calls() { use serde_json::Value; use wiremock::{Request, Respond, ResponseTemplate}; #[derive(schemars::JsonSchema)] struct WeatherArgs { _city: String, } let weather_tool = ToolSpec::builder() .name("get_weather") .description("weather") .parameters_schema(schemars::schema_for!(WeatherArgs)) .build() .unwrap(); let alpha_tool = ToolSpec::builder() .name("alpha_tool") .description("alpha") .parameters_schema(schemars::schema_for!(WeatherArgs)) .build() .unwrap(); let mock_server = MockServer::start().await; let response_body = json!({ "id": "chatcmpl-xyz", "object": 
"chat.completion", "created": 1, "model": "gpt-4o", "choices": [{ "index": 0, "message": { "role": "assistant", "content": "Here", "refusal": null, "annotations": [] }, "finish_reason": "stop" }], "usage": { "prompt_tokens": 2, "completion_tokens": 3, "total_tokens": 5, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0} } }); #[allow(clippy::items_after_statements)] struct Validate(Value); #[allow(clippy::items_after_statements)] impl Respond for Validate { fn respond(&self, request: &Request) -> ResponseTemplate { let v: Value = serde_json::from_slice(&request.body).unwrap(); assert_eq!(v["model"], "gpt-4o"); assert_eq!(v["parallel_tool_calls"], true); assert_eq!(v["tool_choice"], "auto"); let tools = v["tools"].as_array().unwrap(); assert_eq!(tools.len(), 2); let tool_names = tools .iter() .map(|tool| tool["function"]["name"].as_str().unwrap()) .collect::<Vec<_>>(); assert_eq!(tool_names, vec!["alpha_tool", "get_weather"]); ResponseTemplate::new(200) .insert_header("content-type", "application/json") .set_body_json(self.0.clone()) } } Mock::given(method("POST")) .and(path("/chat/completions")) .respond_with(Validate(response_body.clone())) .mount(&mock_server) .await; let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri()); let async_openai = async_openai::Client::with_config(config); let openai = OpenAI::builder() .client(async_openai) .default_prompt_model("gpt-4o") .parallel_tool_calls(Some(true)) .build() .unwrap(); let req = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .tool_specs([weather_tool, alpha_tool]) .build() .unwrap(); let resp = openai.complete(&req).await.unwrap(); assert_eq!(resp.message(), Some("Here")); assert_eq!(resp.usage.unwrap().total_tokens, 5); } #[test_log::test(tokio::test)] async fn test_complete_stream_happy_path() { 
let mock_server = MockServer::start().await; let sse_body = "\ data: {\"id\":\"chatcmpl-123\",\"created\":1,\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}\n\ \n\ data: {\"id\":\"chatcmpl-123\",\"created\":1,\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\ \n\ data: [DONE]\n\n"; Mock::given(method("POST")) .and(path("/chat/completions")) .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream")) .mount(&mock_server) .await; let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri()); let async_openai = async_openai::Client::with_config(config); let openai = OpenAI::builder() .client(async_openai) .default_prompt_model("gpt-4o-mini") .build() .unwrap(); let req = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("Hello".into())]) .build() .unwrap(); let results: Vec<_> = openai.complete_stream(&req).await.collect().await; let last = results.last().unwrap().as_ref().unwrap(); assert_eq!(last.message(), Some("Hi")); assert_eq!(last.usage.as_ref().map(|u| u.total_tokens), Some(3)); } #[test_log::test(tokio::test)] async fn test_complete_stream_delta_only_mode() { let mock_server = MockServer::start().await; let sse_body = "\ data: {\"id\":\"chatcmpl-123\",\"created\":1,\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}\n\ \n\ data: {\"id\":\"chatcmpl-123\",\"created\":1,\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\ \n\ data: [DONE]\n\n"; Mock::given(method("POST")) 
.and(path("/chat/completions")) .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream")) .mount(&mock_server) .await; let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri()); let async_openai = async_openai::Client::with_config(config); let openai = OpenAI::builder() .client(async_openai) .default_prompt_model("gpt-4o-mini") .stream_full(false) .build() .unwrap(); let req = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("Hello".into())]) .build() .unwrap(); let mut stream = openai.complete_stream(&req).await; let first = stream.next().await.unwrap().unwrap(); assert!(first.message.is_none()); assert!(first.usage.is_none()); assert!( first.delta.is_some(), "delta-only mode should emit delta snapshots" ); let final_snapshot = stream.next().await.unwrap().unwrap(); // Final snapshot should arrive in delta-only mode. assert!(final_snapshot.usage.is_none() || final_snapshot.usage.is_some()); while let Some(item) = stream.next().await { item.expect("stream should not error"); } } #[test_log::test(tokio::test)] async fn test_complete_stream_invalid_tool_schema_errors() { let invalid_schema = schemars::Schema::from(true); let err = ToolSpec::builder() .name("bad") .description("bad schema") .parameters_schema(invalid_schema) .build() .expect_err("invalid tool schemas should be rejected at build time"); assert!( err.to_string() .contains("tool schema must be a JSON object") ); } #[test_log::test(tokio::test)] async fn test_complete_invalid_tool_schema_errors() { let invalid_schema = schemars::Schema::from(true); let err = ToolSpec::builder() .name("bad") .description("bad schema") .parameters_schema(invalid_schema) .build() .expect_err("invalid tool schemas should be rejected at build time"); assert!( err.to_string() .contains("tool schema must be a JSON object") ); } #[test_log::test(tokio::test)] async fn test_complete_stream_rate_limit_transient_error() { let mock_server = 
MockServer::start().await; Mock::given(method("POST")) .and(path("/chat/completions")) .respond_with(ResponseTemplate::new(429).set_body_string("rate limit")) .mount(&mock_server) .await; let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri()); let async_openai = async_openai::Client::with_config(config); let openai = OpenAI::builder() .client(async_openai) .default_prompt_model("gpt-4o-mini") .build() .unwrap(); let req = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .build() .unwrap(); let mut stream = openai.complete_stream(&req).await; let first = stream.next().await.expect("stream yields one item"); assert!(matches!(first, Err(LanguageModelError::TransientError(_)))); assert!(stream.next().await.is_none()); } #[test] fn test_message_to_openai_tool_output_without_content() { let tool_call = ToolCallBuilder::default() .id("call_1") .name("noop") .build() .unwrap(); let msg = ChatMessage::ToolOutput(tool_call, ToolOutput::stop()); let converted = message_to_openai(&msg) .expect("conversion succeeds") .expect("message is not filtered"); match converted { async_openai::types::chat::ChatCompletionRequestMessage::Tool(m) => { assert_eq!(m.tool_call_id, "call_1"); assert_eq!( m.content, async_openai::types::chat::ChatCompletionRequestToolMessageContent::Text( String::new() ) ); } other => panic!("expected tool message, got {other:?}"), } } #[test] fn test_message_to_openai_assistant_with_tool_calls_and_text() { let tool_call = ToolCallBuilder::default() .id("call_2") .name("math") .args("{\"x\":1}") .build() .unwrap(); let msg = ChatMessage::new_assistant(Some("pending"), Some(vec![tool_call.clone()])); let converted = message_to_openai(&msg) .expect("conversion succeeds") .expect("message is not filtered"); match converted { async_openai::types::chat::ChatCompletionRequestMessage::Assistant(m) => { assert_eq!(m.content.unwrap(), "pending".into()); let calls = m.tool_calls.unwrap(); 
assert_eq!(calls.len(), 1); let async_openai::types::chat::ChatCompletionMessageToolCalls::Function(call) = &calls[0] else { panic!("expected function tool call"); }; assert_eq!(call.id, "call_2"); assert_eq!(call.function.name, "math"); assert_eq!(call.function.arguments, "{\"x\":1}"); } other => panic!("expected assistant message, got {other:?}"), } } #[test_log::test(tokio::test)] async fn test_complete_stream_model_missing_errors_immediately() { let openai = OpenAI::builder() .default_embed_model("unused") .build() .expect("builder without prompt model still constructs"); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .build() .unwrap(); let mut stream = openai.complete_stream(&request).await; let first = stream.next().await.expect("stream yields one item"); assert!( matches!(first, Err(LanguageModelError::PermanentError(msg)) if msg.to_string().contains("Model not set")) ); assert!(stream.next().await.is_none(), "stream ends after error"); } #[test_log::test(tokio::test)] async fn test_track_completion_invokes_on_usage_callback() { use std::sync::atomic::{AtomicUsize, Ordering}; let hits = Arc::new(AtomicUsize::new(0)); let hits_clone = hits.clone(); let openai = OpenAI::builder() .default_prompt_model("gpt-4o") .on_usage(move |_usage| { hits_clone.fetch_add(1, Ordering::SeqCst); Ok(()) }) .build() .unwrap(); let usage = UsageBuilder::default() .prompt_tokens(1) .completion_tokens(1) .total_tokens(2) .build() .unwrap(); openai.track_completion( "gpt-4o", Some(&usage), Option::<&()>::None, Option::<&()>::None, ); // give spawned task a tick tokio::time::sleep(std::time::Duration::from_millis(10)).await; assert_eq!(hits.load(Ordering::SeqCst), 1); } } ================================================ FILE: swiftide-integrations/src/openai/embed.rs ================================================ use async_openai::types::embeddings::{CreateEmbeddingRequest, CreateEmbeddingResponse}; use async_trait::async_trait; use 
swiftide_core::{
    EmbeddingModel, Embeddings,
    chat_completion::{Usage, errors::LanguageModelError},
};

use super::GenericOpenAI;
use crate::openai::openai_error_to_language_model_error;

#[async_trait]
impl<
    C: async_openai::config::Config
        + std::default::Default
        + Sync
        + Send
        + std::fmt::Debug
        + Clone
        + 'static,
> EmbeddingModel for GenericOpenAI<C>
{
    /// Embeds a batch of strings via the OpenAI embeddings endpoint using the
    /// configured default embed model.
    ///
    /// # Errors
    /// Returns a permanent error when no embed model is configured or the request
    /// cannot be built; API failures are mapped through
    /// `openai_error_to_language_model_error`.
    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
        // A missing model is a configuration problem, not a transient one.
        let model = self
            .default_options
            .embed_model
            .as_ref()
            .ok_or(LanguageModelError::PermanentError("Model not set".into()))?;

        let request = self
            .embed_request_defaults()
            .model(model)
            .input(&input)
            .build()
            .map_err(LanguageModelError::permanent)?;

        tracing::debug!(
            num_chunks = input.len(),
            model = &model,
            "[Embed] Request to openai"
        );

        let response = self
            .client
            .embeddings()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?;

        let usage = Usage::from(&response.usage);

        // Only track usage for embedding calls, as requests and responses are extremely verbose
        self.track_completion(
            model,
            Some(&usage),
            None::<&CreateEmbeddingRequest>,
            None::<&CreateEmbeddingResponse>,
        );

        let num_embeddings = response.data.len();
        tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response openai");

        // WARN: Naively assumes that the order is preserved. Might not always be the case.
        Ok(response.data.into_iter().map(|d| d.embedding).collect())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai::OpenAI;
    use serde_json::json;
    use wiremock::{
        Mock, MockServer, Request, Respond, ResponseTemplate,
        matchers::{method, path},
    };

    // Without a configured embed model, `embed` must fail permanently.
    #[test_log::test(tokio::test)]
    async fn test_embed_returns_error_when_model_missing() {
        let openai = OpenAI::builder().build().unwrap();
        let err = openai.embed(vec!["text".into()]).await.unwrap_err();
        assert!(matches!(err, LanguageModelError::PermanentError(_)));
    }

    // Happy path against a wiremock server that also validates the outgoing body.
    #[allow(clippy::items_after_statements)]
    #[test_log::test(tokio::test)]
    async fn test_embed_success() {
        let mock_server = MockServer::start().await;
        let response_body = json!({
            "data": [{
                "embedding": [0.1, 0.2],
                "index": 0,
                "object": "embedding"
            }],
            "model": "text-embedding-3-small",
            "object": "list",
            "usage": {"prompt_tokens": 5, "total_tokens": 5}
        });

        // Responder that asserts on the request payload before replying.
        struct ValidateEmbeddingRequest(serde_json::Value);

        impl Respond for ValidateEmbeddingRequest {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let body: serde_json::Value = serde_json::from_slice(&request.body).unwrap();
                assert_eq!(body["model"], "text-embedding-3-small");
                assert!(body["input"].is_array());
                ResponseTemplate::new(200).set_body_json(self.0.clone())
            }
        }

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(ValidateEmbeddingRequest(response_body))
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = async_openai::Client::with_config(config);
        let openai = OpenAI::builder()
            .client(client)
            .default_embed_model("text-embedding-3-small")
            .build()
            .unwrap();

        let embeddings = openai
            .embed(vec!["Hello".into(), "World".into()])
            .await
            .unwrap();

        // The mock returns a single two-dimensional embedding.
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0], vec![0.1, 0.2]);
    }
}


================================================
FILE: swiftide-integrations/src/openai/mod.rs
================================================
//! This module provides integration with `OpenAI`'s API, enabling the use of language models and
//! embeddings within the Swiftide project. It includes the `OpenAI` struct for managing API clients
//! and default options for embedding and prompt models. The module is conditionally compiled based
//! on the "openai" feature flag.
use async_openai::error::{OpenAIError, StreamError};
use async_openai::types::chat::CreateChatCompletionRequestArgs;
use async_openai::types::embeddings::CreateEmbeddingRequestArgs;
use derive_builder::Builder;
use reqwest::StatusCode;
use reqwest_eventsource::Error as EventSourceError;
use std::pin::Pin;
use std::sync::Arc;
use swiftide_core::chat_completion::Usage;
use swiftide_core::chat_completion::errors::LanguageModelError;

mod chat_completion;
mod embed;
mod responses_api;
mod simple_prompt;
mod structured_prompt;
mod tool_schema;

// expose type aliases to simplify downstream use of the open ai builder invocations
pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;
pub use async_openai::types::responses::ReasoningEffort;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

/// The `OpenAI` struct encapsulates an `OpenAI` client and default options for embedding and prompt
/// models. It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// # Example
///
/// ```no_run
/// # use swiftide_integrations::openai::{OpenAI, Options};
/// # use swiftide_integrations::openai::OpenAIConfig;
///
/// // Create an OpenAI client with default options. The client will use the OPENAI_API_KEY environment variable.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .build().unwrap();
///
/// // Create an OpenAI client with a custom api key.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .client(async_openai::Client::with_config(async_openai::config::OpenAIConfig::default().with_api_key("my-api-key")))
///     .build().unwrap();
///
/// // Create an OpenAI client with custom options
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .default_options(
///         Options::builder()
///             .temperature(1.0)
///             .parallel_tool_calls(false)
///             .user("MyUserId")
///     )
///     .build().unwrap();
/// ```
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

#[derive(Builder, Clone)]
#[builder(setter(into, strip_option))]
/// Generic client for `OpenAI` APIs.
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The `OpenAI` client, wrapped in an `Arc` for thread-safe reference counting.
    /// Defaults to a new instance of `async_openai::Client`.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options for embedding and prompt models.
    #[builder(default, setter(custom))]
    pub(crate) default_options: Options,

    // Token estimator; only compiled in when the `tiktoken` feature is enabled.
    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder(default))]
    pub(crate) tiktoken: TikToken,

    /// Convenience option to stream the full response. Defaults to true, because nobody has time
    /// to reconstruct the delta. Disabling this will make the streamed content only return the
    /// delta, for when performance matters. This only has effect when streaming is enabled.
    #[builder(default = true)]
    pub stream_full: bool,

    #[cfg(feature = "metrics")]
    #[builder(default)]
    /// Optional metadata to attach to metrics emitted by this client.
    metric_metadata: Option<std::collections::HashMap<String, String>>,

    /// Opt-in flag to use `OpenAI`'s Responses API instead of the legacy Chat Completions API.
    #[builder(default)]
    pub(crate) use_responses_api: bool,

    /// A callback function that is called when usage information is available.
    // Boxed async closure: it borrows the `Usage` for the lifetime of the returned future.
    #[builder(default, setter(custom))]
    #[allow(clippy::type_complexity)]
    on_usage: Option<
        Arc<
            dyn for<'a> Fn(
                    &'a Usage,
                ) -> Pin<
                    Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>,
                > + Send
                + Sync,
        >,
    >,
}

// Manual `Debug` impl: `on_usage` holds a trait-object closure with no `Debug`
// implementation, so we print the stable fields and elide the rest via
// `finish_non_exhaustive`.
impl<C: async_openai::config::Config + Default + std::fmt::Debug> std::fmt::Debug
    for GenericOpenAI<C>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GenericOpenAI")
            .field("client", &self.client)
            .field("default_options", &self.default_options)
            .field("stream_full", &self.stream_full)
            .field("use_responses_api", &self.use_responses_api)
            .finish_non_exhaustive()
    }
}

/// The `Options` struct holds configuration options for the `OpenAI` client.
/// It includes optional fields for specifying the embedding and prompt models.
#[derive(Debug, Clone, Builder, Default)]
#[builder(setter(strip_option))]
pub struct Options {
    /// The default embedding model to use, if specified.
    #[builder(default, setter(into))]
    pub embed_model: Option<String>,

    /// The default prompt model to use, if specified.
    #[builder(default, setter(into))]
    pub prompt_model: Option<String>,

    #[builder(default)]
    /// Option to enable or disable parallel tool calls for completions.
    ///
    /// At this moment, o1 and o3-mini do not support it and should be set to `None`.
    pub parallel_tool_calls: Option<bool>,

    /// Maximum number of tokens to generate in the completion.
    ///
    /// By default, the limit is disabled
    #[builder(default)]
    pub max_completion_tokens: Option<u32>,

    /// Temperature setting for the model.
    #[builder(default)]
    pub temperature: Option<f32>,

    /// Reasoning effort for reasoning models.
    #[builder(default, setter(into))]
    pub reasoning_effort: Option<ReasoningEffort>,

    /// Enable reasoning summary/encrypted content handling for the Responses API.
    ///
    /// This is enabled by default, but only takes effect when `reasoning_effort` is set.
    /// Disable it with `reasoning_features(false)` if you do not want summaries or encrypted
    /// reasoning stored and replayed.
    ///
    /// Note: reasoning summaries/encrypted content require an `OpenAI` organization that is
    /// verified for reasoning access; unverified orgs may receive no summaries.
    #[builder(default, setter(into))]
    pub reasoning_features: Option<bool>,

    /// This feature is in Beta. If specified, our system will make a best effort to sample
    /// deterministically, such that repeated requests with the same seed and parameters should
    /// return the same result. Determinism is not guaranteed, and you should refer to the
    /// `system_fingerprint` response parameter to monitor changes in the backend.
    #[builder(default)]
    pub seed: Option<i64>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
    /// appear in the text so far, increasing the model’s likelihood to talk about new topics.
    #[builder(default)]
    pub presence_penalty: Option<f32>,

    /// Developer-defined tags and values used for filtering completions in the dashboard.
    #[builder(default, setter(into))]
    pub metadata: Option<serde_json::Value>,

    /// A unique identifier representing your end-user, which can help `OpenAI` to monitor and
    /// detect abuse.
    #[builder(default, setter(into))]
    pub user: Option<String>,

    #[builder(default)]
    /// The number of dimensions the resulting output embeddings should have. Only supported in
    /// text-embedding-3 and later models.
    pub dimensions: Option<u32>,
}

impl Options {
    /// Creates a new `OptionsBuilder` for constructing `Options` instances.
pub fn builder() -> OptionsBuilder { OptionsBuilder::default() } /// Extends options with other options pub fn merge(&mut self, other: &Options) { if let Some(embed_model) = &other.embed_model { self.embed_model = Some(embed_model.clone()); } if let Some(prompt_model) = &other.prompt_model { self.prompt_model = Some(prompt_model.clone()); } if let Some(parallel_tool_calls) = other.parallel_tool_calls { self.parallel_tool_calls = Some(parallel_tool_calls); } if let Some(max_completion_tokens) = other.max_completion_tokens { self.max_completion_tokens = Some(max_completion_tokens); } if let Some(temperature) = other.temperature { self.temperature = Some(temperature); } if let Some(reasoning_effort) = &other.reasoning_effort { self.reasoning_effort = Some(reasoning_effort.clone()); } if let Some(reasoning_features) = other.reasoning_features { self.reasoning_features = Some(reasoning_features); } if let Some(seed) = other.seed { self.seed = Some(seed); } if let Some(presence_penalty) = other.presence_penalty { self.presence_penalty = Some(presence_penalty); } if let Some(metadata) = &other.metadata { self.metadata = Some(metadata.clone()); } if let Some(user) = &other.user { self.user = Some(user.clone()); } if let Some(dimensions) = other.dimensions { self.dimensions = Some(dimensions); } } } impl From<OptionsBuilder> for Options { fn from(value: OptionsBuilder) -> Self { Self { embed_model: value.embed_model.flatten(), prompt_model: value.prompt_model.flatten(), parallel_tool_calls: value.parallel_tool_calls.flatten(), max_completion_tokens: value.max_completion_tokens.flatten(), temperature: value.temperature.flatten(), reasoning_effort: value.reasoning_effort.flatten(), reasoning_features: value.reasoning_features.flatten(), presence_penalty: value.presence_penalty.flatten(), seed: value.seed.flatten(), metadata: value.metadata.flatten(), user: value.user.flatten(), dimensions: value.dimensions.flatten(), } } } impl From<&mut OptionsBuilder> for Options { fn 
from(value: &mut OptionsBuilder) -> Self { let value = value.clone(); Self { embed_model: value.embed_model.flatten(), prompt_model: value.prompt_model.flatten(), parallel_tool_calls: value.parallel_tool_calls.flatten(), max_completion_tokens: value.max_completion_tokens.flatten(), temperature: value.temperature.flatten(), reasoning_effort: value.reasoning_effort.flatten(), reasoning_features: value.reasoning_features.flatten(), presence_penalty: value.presence_penalty.flatten(), seed: value.seed.flatten(), metadata: value.metadata.flatten(), user: value.user.flatten(), dimensions: value.dimensions.flatten(), } } } impl OpenAI { /// Creates a new `OpenAIBuilder` for constructing `OpenAI` instances. pub fn builder() -> OpenAIBuilder { let mut builder = OpenAIBuilder::default(); builder.default_options(Options { reasoning_features: Some(true), ..Default::default() }); builder } } impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug> GenericOpenAIBuilder<C> { /// Adds a callback function that will be called when usage information is available. pub fn on_usage<F>(&mut self, func: F) -> &mut Self where F: Fn(&Usage) -> anyhow::Result<()> + Send + Sync + 'static, { let func = Arc::new(func); self.on_usage = Some(Some(Arc::new(move |usage: &Usage| { let func = func.clone(); Box::pin(async move { func(usage) }) }))); self } /// Adds an asynchronous callback function that will be called when usage information is /// available. pub fn on_usage_async<F>(&mut self, func: F) -> &mut Self where F: for<'a> Fn( &'a Usage, ) -> Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>> + Send + Sync + 'static, { let func = Arc::new(func); self.on_usage = Some(Some(Arc::new(move |usage: &Usage| { let func = func.clone(); Box::pin(async move { func(usage).await }) }))); self } /// Sets the `OpenAI` client for the `OpenAI` instance. /// /// # Parameters /// - `client`: The `OpenAI` client to set. 
/// /// # Returns /// A mutable reference to the `OpenAIBuilder`. pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self { self.client = Some(Arc::new(client)); self } /// Sets the default embedding model for the `OpenAI` instance. /// /// # Parameters /// - `model`: The embedding model to set. /// /// # Returns /// A mutable reference to the `OpenAIBuilder`. pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self { if let Some(options) = self.default_options.as_mut() { options.embed_model = Some(model.into()); } else { self.default_options = Some(Options { embed_model: Some(model.into()), ..Default::default() }); } self } /// Sets the `user` field used by `OpenAI` to monitor and detect usage and abuse. pub fn for_end_user(&mut self, user: impl Into<String>) -> &mut Self { if let Some(options) = self.default_options.as_mut() { options.user = Some(user.into()); } else { self.default_options = Some(Options { user: Some(user.into()), ..Default::default() }); } self } /// Enable or disable parallel tool calls for completions. /// /// Note that currently reasoning models do not support parallel tool calls /// /// Defaults to `true` pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self { if let Some(options) = self.default_options.as_mut() { options.parallel_tool_calls = parallel_tool_calls; } else { self.default_options = Some(Options { parallel_tool_calls, ..Default::default() }); } self } /// Sets the default prompt model for the `OpenAI` instance. /// /// # Parameters /// - `model`: The prompt model to set. /// /// # Returns /// A mutable reference to the `OpenAIBuilder`. 
pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self { if let Some(options) = self.default_options.as_mut() { options.prompt_model = Some(model.into()); } else { self.default_options = Some(Options { prompt_model: Some(model.into()), ..Default::default() }); } self } /// Sets the default options to use for requests to the `OpenAI` API. /// /// Merges with any existing options pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self { if let Some(existing_options) = self.default_options.as_mut() { existing_options.merge(&options.into()); } else { self.default_options = Some(options.into()); } self } } impl<C: async_openai::config::Config + Default> GenericOpenAI<C> { /// Estimates the number of tokens for implementors of the `Estimatable` trait. /// /// I.e. `String`, `ChatMessage` etc /// /// # Errors /// /// Errors if tokinization fails in any way #[cfg(feature = "tiktoken")] pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> { self.tiktoken.estimate(value).await } pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self { self.default_options = Options { prompt_model: Some(model.into()), ..self.default_options.clone() }; self } pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self { self.default_options = Options { embed_model: Some(model.into()), ..self.default_options.clone() }; self } /// Retrieve a reference to the inner `OpenAI` client. pub fn client(&self) -> &Arc<async_openai::Client<C>> { &self.client } /// Retrieve a reference to the default options for the `OpenAI` instance. pub fn options(&self) -> &Options { &self.default_options } /// Retrieve a mutable reference to the default options for the `OpenAI` instance. pub fn options_mut(&mut self) -> &mut Options { &mut self.default_options } /// Returns whether the Responses API is enabled for this client. 
    pub fn is_responses_api_enabled(&self) -> bool {
        self.use_responses_api
    }

    /// Seeds a Chat Completions request with the configured default options.
    ///
    /// Note: `reasoning_effort` and `dimensions` are deliberately not applied
    /// here — see `test_chat_completion_request_defaults_omits_reasoning_effort`
    /// below; `dimensions` is embedding-only.
    fn chat_completion_request_defaults(&self) -> CreateChatCompletionRequestArgs {
        let mut args = CreateChatCompletionRequestArgs::default();

        let options = &self.default_options;
        if let Some(parallel_tool_calls) = options.parallel_tool_calls {
            args.parallel_tool_calls(parallel_tool_calls);
        }
        if let Some(max_tokens) = options.max_completion_tokens {
            args.max_completion_tokens(max_tokens);
        }
        if let Some(temperature) = options.temperature {
            args.temperature(temperature);
        }
        if let Some(seed) = options.seed {
            args.seed(seed);
        }
        if let Some(presence_penalty) = options.presence_penalty {
            args.presence_penalty(presence_penalty);
        }
        if let Some(metadata) = &options.metadata {
            args.metadata(metadata.clone());
        }
        if let Some(user) = &options.user {
            args.user(user.clone());
        }
        args
    }

    /// Seeds an embeddings request with the configured default options
    /// (`user` and `dimensions` are the only ones that apply).
    fn embed_request_defaults(&self) -> CreateEmbeddingRequestArgs {
        let mut args = CreateEmbeddingRequestArgs::default();
        let options = &self.default_options;
        if let Some(user) = &options.user {
            args.user(user.clone());
        }
        if let Some(dimensions) = options.dimensions {
            args.dimensions(dimensions);
        }
        args
    }
}

/// Classifies an `OpenAIError` into the swiftide `LanguageModelError` taxonomy:
/// context-length overflows, permanent failures, and transient (retryable) ones.
pub fn openai_error_to_language_model_error(e: OpenAIError) -> LanguageModelError {
    match e {
        OpenAIError::ApiError(api_error) => {
            // If the response is an ApiError, it could be a context length exceeded error
            if api_error.code == Some("context_length_exceeded".to_string()) {
                LanguageModelError::context_length_exceeded(OpenAIError::ApiError(api_error))
            } else {
                LanguageModelError::permanent(OpenAIError::ApiError(api_error))
            }
        }
        OpenAIError::Reqwest(e) => {
            // async_openai passes any network errors as reqwest errors, so we just assume they are
            // recoverable
            LanguageModelError::transient(e)
        }
        OpenAIError::JSONDeserialize(_, _) => {
            // OpenAI generated a non-json response, probably a temporary problem on their side
            // (i.e. reverse proxy can't find an available backend)
            LanguageModelError::transient(e)
        }
        OpenAIError::StreamError(stream_error) => {
            // Note that this will _retry_ the stream. We have to assume that the stream just
            // started if a 429 happens. For future readers, internally the streaming crate
            // (eventsource) already applies backoff.
            if is_rate_limited_stream_error(&stream_error) {
                LanguageModelError::transient(OpenAIError::StreamError(stream_error))
            } else {
                LanguageModelError::permanent(OpenAIError::StreamError(stream_error))
            }
        }
        OpenAIError::FileSaveError(_)
        | OpenAIError::FileReadError(_)
        | OpenAIError::InvalidArgument(_) => LanguageModelError::permanent(e),
    }
}

/// Returns true when a streaming error is an HTTP 429, either as an invalid
/// status code on the event source or on the underlying transport error.
fn is_rate_limited_stream_error(error: &StreamError) -> bool {
    match error {
        StreamError::ReqwestEventSource(inner) => match inner {
            EventSourceError::InvalidStatusCode(status, _) => {
                *status == StatusCode::TOO_MANY_REQUESTS
            }
            EventSourceError::Transport(source) => {
                source.status() == Some(StatusCode::TOO_MANY_REQUESTS)
            }
            _ => false,
        },
        StreamError::UnknownEvent(_) | StreamError::EventStream(_) => false,
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use async_openai::error::{ApiError, OpenAIError, StreamError};
    use eventsource_stream::Event;

    /// test default embed model
    #[test]
    fn test_default_embed_and_prompt_model() {
        let openai: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );

        // Setter order must not matter.
        let openai: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
    }

    #[test]
    fn test_use_responses_api_flag() {
        let openai: OpenAI = OpenAI::builder().use_responses_api(true).build().unwrap();
        assert!(openai.is_responses_api_enabled());
    }

    #[test]
    fn test_context_length_exceeded_error() {
        // Create an API error with the context_length_exceeded code
        let api_error = ApiError {
            message: "This model's maximum context length is 8192 tokens".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("messages".to_string()),
            code: Some("context_length_exceeded".to_string()),
        };
        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as ContextLengthExceeded
        match result {
            LanguageModelError::ContextLengthExceeded(_) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded error, got {result:?}"),
        }
    }

    #[test]
    fn test_api_error_permanent() {
        // Create a generic API error (not context length exceeded)
        let api_error = ApiError {
            message: "Invalid API key".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("api_key".to_string()),
            code: Some("invalid_api_key".to_string()),
        };
        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_save_error_is_permanent() {
        // Create a file save error
        let openai_error = OpenAIError::FileSaveError("Failed to save file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_read_error_is_permanent() {
        // Create a file read error
        let openai_error = OpenAIError::FileReadError("Failed to read file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_stream_error_is_permanent() {
        // Create a stream error
        let openai_error =
            OpenAIError::StreamError(Box::new(StreamError::UnknownEvent(Event::default())));
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_invalid_argument_is_permanent() {
        // Create an invalid argument error
        let openai_error = OpenAIError::InvalidArgument("Invalid argument".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_options_merge_overrides_set_fields() {
        let mut base = Options::builder()
            .prompt_model("a")
            .temperature(0.1)
            .build()
            .unwrap();
        let overlay = Options::builder()
            .prompt_model("b")
            .presence_penalty(0.2)
            .build()
            .unwrap();

        base.merge(&overlay);

        // `Some` fields in the overlay win; fields the overlay leaves `None`
        // keep their existing value.
        assert_eq!(base.prompt_model.as_deref(), Some("b"));
        assert_eq!(base.temperature, Some(0.1));
        assert_eq!(base.presence_penalty, Some(0.2));
    }

    #[test]
    #[allow(deprecated)]
    fn test_chat_completion_request_defaults_omits_reasoning_effort() {
        let openai: OpenAI = OpenAI::builder()
            .default_options(
                Options::builder()
                    .parallel_tool_calls(true)
                    .max_completion_tokens(42)
                    .temperature(0.3)
                    .reasoning_effort(ReasoningEffort::Low)
                    .seed(7)
                    .presence_penalty(1.1)
                    .metadata(serde_json::json!({"tag": "demo"}))
                    .user("user-1"),
            )
            .build()
            .unwrap();

        let built = openai
            .chat_completion_request_defaults()
            .messages(Vec::new())
            .model("gpt-4o")
            .build()
            .unwrap();

        assert_eq!(built.parallel_tool_calls, Some(true));
        assert_eq!(built.max_completion_tokens, Some(42));
        assert_eq!(built.temperature, Some(0.3));
        // reasoning_effort is a Responses API concept and must not leak into
        // Chat Completions requests.
        assert_eq!(built.reasoning_effort, None);
        assert_eq!(built.seed, Some(7));
        assert_eq!(built.presence_penalty, Some(1.1));
        assert_eq!(
            built.metadata,
            Some(async_openai::types::Metadata::from(
                serde_json::json!({"tag": "demo"})
            ))
        );
        assert_eq!(built.user, Some("user-1".to_string()));
    }

    #[test]
    #[allow(deprecated)]
    fn test_embed_request_defaults_sets_user_and_dimensions() {
        let openai: OpenAI = OpenAI::builder()
            .default_options(Options::builder().user("end-user").dimensions(128))
            .build()
            .unwrap();

        let built = openai
            .embed_request_defaults()
            .model("text-embedding-3-small")
            .input("hello")
            .build()
            .unwrap();

        assert_eq!(built.user, Some("end-user".to_string()));
        assert_eq!(built.dimensions, Some(128));
    }
}

================================================ FILE: swiftide-integrations/src/openai/responses_api.rs ================================================

use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};

use anyhow::{Context as _, Result};
use async_openai::types::responses::{
    CreateResponse, CreateResponseArgs, EasyInputContent, EasyInputMessageArgs, FunctionCallOutput,
    FunctionCallOutputItemParam, FunctionTool, FunctionToolCall, ImageDetail, IncludeEnum,
    InputContent, InputFileArgs, InputImageContent, InputItem, InputParam, InputTextContent,
    MessageType, OutputContent, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
    ReasoningArgs, ReasoningSummary, Response, ResponseFormatJsonSchema, ResponseStream,
    ResponseStreamEvent, ResponseTextParam, Role, Status, TextResponseFormatConfiguration, Tool,
    ToolChoiceOptions, ToolChoiceParam,
};
use base64::Engine as _;
use futures_util::Stream;
use swiftide_core::chat_completion::{
    ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ChatMessageContentPart,
    ChatMessageContentSource, ReasoningItem, ToolCall, ToolOutput, ToolSpec, Usage,
};

use super::tool_schema::OpenAiToolSchema;
use super::{GenericOpenAI, openai_error_to_language_model_error};
use
crate::openai::LanguageModelError;

// Local shorthand for fallible operations in this module.
type LmResult<T> = Result<T, LanguageModelError>;

/// Translates a swiftide chat completion request into a Responses API request.
///
/// Fails with a permanent error when no default prompt model is configured.
/// Tool specs are converted to Responses function tools.
pub(super) fn build_responses_request_from_chat<C>(
    client: &GenericOpenAI<C>,
    request: &ChatCompletionRequest<'_>,
) -> LmResult<CreateResponse>
where
    C: async_openai::config::Config + Clone + Default,
{
    let model = client
        .options()
        .prompt_model
        .as_ref()
        .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

    let mut args = base_request_args(client, model)?;
    let options = client.options();
    // Stored reasoning items are only replayed when a reasoning effort is configured.
    let include_reasoning = options.reasoning_effort.is_some();
    let input_items = chat_messages_to_input_items(request.messages(), include_reasoning)?;
    args.input(InputParam::Items(input_items));

    if !request.tools_spec().is_empty() {
        let tools = request
            .tools_spec()
            .iter()
            .map(tool_spec_to_responses_tool)
            .collect::<Result<Vec<_>>>()
            .map_err(LanguageModelError::permanent)?;
        args.tools(tools);
        // NOTE(review): tool choice is only forced to `Auto` when parallel tool
        // calls are enabled (the default) — confirm this coupling is intended.
        if client.options().parallel_tool_calls.unwrap_or(true) {
            args.tool_choice(ToolChoiceParam::Mode(ToolChoiceOptions::Auto));
        }
    }

    args.build().map_err(openai_error_to_language_model_error)
}

/// Seeds a `CreateResponseArgs` with the client's default options.
///
/// Options without a Responses API equivalent (`seed`, `presence_penalty`)
/// are logged and skipped rather than silently dropped.
fn base_request_args<C>(client: &GenericOpenAI<C>, model: &str) -> LmResult<CreateResponseArgs>
where
    C: async_openai::config::Config + Clone + Default,
{
    let mut args = CreateResponseArgs::default();
    args.model(model);
    let options = client.options();
    if let Some(parallel_tool_calls) = options.parallel_tool_calls {
        args.parallel_tool_calls(parallel_tool_calls);
    }
    if let Some(max_tokens) = options.max_completion_tokens {
        args.max_output_tokens(max_tokens);
    }
    if let Some(temperature) = options.temperature {
        args.temperature(temperature);
    }
    if let Some(reasoning_effort) = options.reasoning_effort.clone() {
        let mut reasoning = ReasoningArgs::default();
        reasoning.effort(reasoning_effort);
        // reasoning_features defaults to on: request summaries and encrypted
        // reasoning content so it can be replayed in follow-up requests.
        if options.reasoning_features.unwrap_or(true) {
            reasoning.summary(ReasoningSummary::Auto);
            args.include(vec![IncludeEnum::ReasoningEncryptedContent]);
        }
        let reasoning = reasoning.build().map_err(LanguageModelError::permanent)?;
        args.reasoning(reasoning);
        // Reasoning models should always be stateless in Responses API usage.
        args.store(false);
    }
    if let Some(seed) = options.seed {
        tracing::warn!(
            seed,
            "`seed` is not supported by the Responses API; ignoring"
        );
    }
    if let Some(presence_penalty) = options.presence_penalty {
        tracing::warn!(
            presence_penalty,
            "`presence_penalty` is not supported by the Responses API; ignoring"
        );
    }
    if let Some(metadata) = options.metadata.as_ref() {
        if let Some(converted) = convert_metadata(metadata) {
            args.metadata(converted);
        } else {
            tracing::warn!("Responses metadata must be a flat map of string values; skipping");
        }
    }
    Ok(args)
}

/// Converts JSON metadata into the flat string map the Responses API expects.
///
/// Returns `None` when the value is not an object or contains non-string
/// values, so the caller can warn instead of sending invalid metadata.
fn convert_metadata(value: &serde_json::Value) -> Option<HashMap<String, String>> {
    match value {
        serde_json::Value::Object(map) => {
            let mut out = HashMap::with_capacity(map.len());
            for (key, val) in map {
                if let Some(s) = val.as_str() {
                    out.insert(key.clone(), s.to_owned());
                } else {
                    return None;
                }
            }
            Some(out)
        }
        _ => None,
    }
}

/// Converts a swiftide tool spec into a Responses API function tool with
/// strict schema checking enabled.
fn tool_spec_to_responses_tool(spec: &ToolSpec) -> Result<Tool> {
    let parameters = OpenAiToolSchema::try_from(spec)
        .context("tool schema must be OpenAI compatible")?
        .into_value();
    let function = FunctionTool {
        name: spec.name.clone(),
        parameters: Some(parameters),
        strict: Some(true),
        description: Some(spec.description.clone()),
    };
    Ok(Tool::Function(function))
}

/// Converts the chat history into Responses API input items.
///
/// `include_reasoning` controls whether stored reasoning items (encrypted
/// content) are replayed into the request.
fn chat_messages_to_input_items(
    messages: &[ChatMessage],
    include_reasoning: bool,
) -> LmResult<Vec<InputItem>> {
    let mut items = Vec::with_capacity(messages.len());
    for message in messages {
        match message {
            ChatMessage::System(content) => {
                items.push(message_item(Role::System, content.clone())?);
            }
            ChatMessage::User(content) => {
                items.push(message_item(Role::User, content.clone())?);
            }
            ChatMessage::UserWithParts(parts) => {
                let content = user_parts_to_easy_input_content(parts)?;
                items.push(message_item_with_content(Role::User, content)?);
            }
            ChatMessage::Assistant(content, tool_calls) => {
                if let Some(text) = content.as_ref() {
                    items.push(message_item(Role::Assistant, text.clone())?);
                }
                // Each historical tool call becomes its own function-call item.
                if let Some(tool_calls) = tool_calls.as_ref() {
                    for tool_call in tool_calls {
                        let call_id = normalize_responses_function_call_id(tool_call.id());
                        let arguments = tool_call.args().unwrap_or_default().to_owned();
                        let function_call = FunctionToolCall {
                            arguments,
                            call_id: call_id.clone(),
                            name: tool_call.name().to_owned(),
                            id: None,
                            status: Some(OutputStatus::InProgress),
                        };
                        items.push(InputItem::Item(
                            async_openai::types::responses::Item::FunctionCall(function_call),
                        ));
                    }
                }
            }
            ChatMessage::ToolOutput(tool_call, tool_output) => {
                // All tool output variants are flattened to plain text for the API.
                let output = match tool_output {
                    ToolOutput::FeedbackRequired(value)
                    | ToolOutput::Stop(value)
                    | ToolOutput::AgentFailed(value) => FunctionCallOutput::Text(
                        value
                            .as_ref()
                            .map_or_else(String::new, serde_json::Value::to_string),
                    ),
                    ToolOutput::Text(text) | ToolOutput::Fail(text) => {
                        FunctionCallOutput::Text(text.clone())
                    }
                    _ => FunctionCallOutput::Text(String::new()),
                };
                let function_output = FunctionCallOutputItemParam {
                    call_id: normalize_responses_function_call_id(tool_call.id()),
                    output,
                    id: None,
                    status: Some(OutputStatus::Completed),
                };
                items.push(InputItem::Item(
                    async_openai::types::responses::Item::FunctionCallOutput(function_output),
                ));
            }
            ChatMessage::Reasoning(item) => {
                // Only replay reasoning items that carry non-empty encrypted
                // content, and only when reasoning replay is enabled.
                if !include_reasoning
                    || item.encrypted_content.is_none()
                    || item
                        .encrypted_content
                        .as_ref()
                        .is_some_and(String::is_empty)
                {
                    continue;
                }
                let reasoning_item = async_openai::types::responses::ReasoningItem {
                    id: item.id.clone(),
                    summary: Vec::new(),
                    content: None,
                    encrypted_content: item.encrypted_content.clone(),
                    status: None,
                };
                items.push(InputItem::Item(
                    async_openai::types::responses::Item::Reasoning(reasoning_item),
                ));
            }
            ChatMessage::Summary(content) => {
                items.push(message_item(Role::Assistant, content.clone())?);
            }
        }
    }
    Ok(items)
}

/// Builds a plain-text message input item for the given role.
fn message_item(role: Role, content: String) -> LmResult<InputItem> {
    message_item_with_content(role, EasyInputContent::Text(content))
}

/// Builds a message input item with arbitrary (text or multi-part) content.
fn message_item_with_content(role: Role, content: EasyInputContent) -> LmResult<InputItem> {
    Ok(InputItem::EasyMessage(
        EasyInputMessageArgs::default()
            .r#type(MessageType::Message)
            .role(role)
            .content(content)
            .build()
            .map_err(LanguageModelError::permanent)?,
    ))
}

/// Maps multi-part user content onto the Responses API content list.
fn user_parts_to_easy_input_content(
    parts: &[ChatMessageContentPart],
) -> LmResult<EasyInputContent> {
    let mapped = parts
        .iter()
        .map(part_to_input_content)
        .collect::<LmResult<Vec<_>>>()?;
    Ok(EasyInputContent::ContentList(mapped))
}

/// Converts a single content part; audio/video parts and s3-backed sources are
/// rejected with a permanent error.
fn part_to_input_content(part: &ChatMessageContentPart) -> LmResult<InputContent> {
    Ok(match part {
        ChatMessageContentPart::Text { text } => {
            InputContent::from(InputTextContent::from(text.as_str()))
        }
        ChatMessageContentPart::Image {
            source, ..
        } => {
            let image = match source {
                ChatMessageContentSource::Url { url } => InputImageContent {
                    detail: ImageDetail::default(),
                    file_id: None,
                    image_url: Some(url.clone()),
                },
                ChatMessageContentSource::FileId { file_id } => InputImageContent {
                    detail: ImageDetail::default(),
                    file_id: Some(file_id.clone()),
                    image_url: None,
                },
                ChatMessageContentSource::Bytes { data, media_type } => {
                    // Inline bytes are sent as a base64 data URL.
                    let media_type = media_type.as_deref().unwrap_or("application/octet-stream");
                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
                    InputImageContent {
                        detail: ImageDetail::default(),
                        file_id: None,
                        image_url: Some(format!("data:{media_type};base64,{encoded}")),
                    }
                }
                ChatMessageContentSource::S3 { .. } => {
                    return Err(LanguageModelError::permanent(
                        "OpenAI responses input_image does not support s3 sources",
                    ));
                }
            };
            InputContent::from(image)
        }
        ChatMessageContentPart::Document {
            source,
            format,
            name,
        } => {
            let mut builder = InputFileArgs::default();
            // Prefer the explicit name; fall back to a name derived from the
            // format, then to a generic placeholder.
            let filename = name
                .as_deref()
                .map(str::to_owned)
                .or_else(|| format.as_ref().map(|ext| format!("document.{ext}")))
                .unwrap_or_else(|| "document".to_string());
            match source {
                ChatMessageContentSource::Url { url } => {
                    builder.file_url(url.as_str());
                }
                ChatMessageContentSource::FileId { file_id } => {
                    builder.file_id(file_id.as_str());
                }
                ChatMessageContentSource::Bytes { data, .. } => {
                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
                    builder.file_data(encoded).filename(filename);
                }
                ChatMessageContentSource::S3 { .. } => {
                    return Err(LanguageModelError::permanent(
                        "OpenAI responses input_file does not support s3 sources",
                    ));
                }
            }
            InputContent::from(builder.build().map_err(LanguageModelError::permanent)?)
        }
        ChatMessageContentPart::Audio { .. } => {
            return Err(LanguageModelError::permanent(
                "OpenAI responses API does not support audio parts in chat conversion",
            ));
        }
        ChatMessageContentPart::Video { ..
        } => {
            return Err(LanguageModelError::permanent(
                "OpenAI responses API does not support video parts in chat conversion",
            ));
        }
    })
}

/// Normalizes tool-call ids to the Responses API `fc_` form.
///
/// Ids already prefixed `fc_` pass through; chat-completions style `call_`
/// prefixes are rewritten; anything else is returned untouched.
fn normalize_responses_function_call_id(id: &str) -> String {
    if id.starts_with("fc_") {
        id.to_owned()
    } else if let Some(stripped) = id.strip_prefix("call_") {
        format!("fc_{stripped}")
    } else {
        id.to_owned()
    }
}

/// Accumulator that folds Responses API stream events into a single
/// `ChatCompletionResponse`.
#[derive(Default)]
pub(super) struct ResponsesStreamState {
    response: ChatCompletionResponse,
    finished: bool,
}

/// One emitted item of the adapted stream: the (partial or final) response and
/// whether the stream has finished.
#[derive(Debug, Clone)]
pub(super) struct ResponsesStreamItem {
    pub response: ChatCompletionResponse,
    pub finished: bool,
}

impl ResponsesStreamState {
    /// Folds a single stream event into the accumulated response.
    ///
    /// Returns `Ok(Some(item))` when the event produced something worth
    /// emitting, `Ok(None)` when it was absorbed silently, and `Err` on
    /// terminal failure events. Events arriving after the stream has finished
    /// are ignored.
    #[allow(clippy::too_many_lines)]
    fn apply_event(
        &mut self,
        event: ResponseStreamEvent,
        stream_full: bool,
    ) -> LmResult<Option<ResponsesStreamItem>> {
        if self.finished {
            return Ok(None);
        }
        let maybe_item = match event {
            ResponseStreamEvent::ResponseOutputTextDelta(delta) => {
                self.response
                    .append_message_delta(Some(delta.delta.as_str()));
                Some(self.emit(stream_full, false))
            }
            ResponseStreamEvent::ResponseContentPartAdded(part) => match &part.part {
                OutputContent::OutputText(text) => {
                    self.response.append_message_delta(Some(text.text.as_str()));
                    Some(self.emit(stream_full, false))
                }
                _ => None,
            },
            ResponseStreamEvent::ResponseOutputItemAdded(event) => match event.item {
                OutputItem::FunctionCall(function_call) => {
                    let index = event.output_index as usize;
                    let id = function_call_identifier(&function_call);
                    // Empty argument strings are treated as "no delta yet".
                    let arguments = (!function_call.arguments.is_empty())
                        .then_some(function_call.arguments.as_str());
                    self.response.append_tool_call_delta(
                        index,
                        Some(id),
                        Some(function_call.name.as_str()),
                        arguments,
                    );
                    Some(self.emit(stream_full, false))
                }
                OutputItem::Message(message) => {
                    collect_message_text_from_message(&message).map(|text| {
                        self.response.append_message_delta(Some(text.as_str()));
                        self.emit(stream_full, false)
                    })
                }
                _ => None,
            },
            ResponseStreamEvent::ResponseOutputItemDone(event) => {
                if let OutputItem::FunctionCall(function_call) = event.item {
                    let index = event.output_index as usize;
                    let id = function_call_identifier(&function_call);
                    self.response.append_tool_call_delta(
                        index,
                        Some(id),
                        Some(function_call.name.as_str()),
                        None,
                    );
                    Some(self.emit(stream_full, false))
                } else {
                    None
                }
            }
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(delta) => {
                let index = delta.output_index as usize;
                self.response
                    .append_tool_call_delta(index, None, None, Some(delta.delta.as_str()));
                Some(self.emit(stream_full, false))
            }
            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(done) => {
                let index = done.output_index as usize;
                let name = done.name.as_deref().filter(|n| !n.is_empty());
                let mut arguments = None;
                if !done.arguments.is_empty() {
                    let new_args = done.arguments.as_str();
                    // Skip arguments we already accumulated from deltas so the
                    // full argument string is not appended a second time.
                    let duplicate = self
                        .response
                        .tool_calls
                        .as_ref()
                        .and_then(|calls| calls.get(index))
                        .and_then(|tc| tc.args())
                        .is_some_and(|existing| existing == new_args);
                    if !duplicate {
                        arguments = Some(new_args);
                    }
                }
                if name.is_some() || arguments.is_some() {
                    self.response
                        .append_tool_call_delta(index, None, name, arguments);
                    Some(self.emit(stream_full, false))
                } else {
                    None
                }
            }
            ResponseStreamEvent::ResponseCompleted(completed) => {
                metadata_to_chat_completion(&completed.response, &mut self.response)?;
                self.response.delta = None;
                self.finished = true;
                Some(self.emit(stream_full, true))
            }
            ResponseStreamEvent::ResponseIncomplete(incomplete) => {
                metadata_to_chat_completion(&incomplete.response, &mut self.response)?;
                self.response.delta = None;
                self.finished = true;
                Some(self.emit(stream_full, true))
            }
            ResponseStreamEvent::ResponseFailed(failed) => {
                self.finished = true;
                let message = failed.response.error.as_ref().map_or_else(
                    || "Responses API stream failed".to_string(),
                    |err| format!("{}: {}", err.code, err.message),
                );
                return Err(LanguageModelError::permanent(message));
            }
            ResponseStreamEvent::ResponseError(error) => {
                self.finished = true;
                return Err(LanguageModelError::permanent(error.message));
            }
            _ => None,
        };
        Ok(maybe_item)
    }

    fn emit(&mut self,
stream_full: bool, finished: bool) -> ResponsesStreamItem {
        let response = if finished {
            // Stream is complete; move the accumulated response out of state.
            let mut response = std::mem::take(&mut self.response); response.delta = None; response
        } else if stream_full {
            // Caller wants the whole accumulated response on every chunk.
            self.response.clone()
        } else {
            // Delta-only mode: ship just the latest chunk alongside the id.
            // NOTE(review): `id` is read from `self.response` by value here —
            // presumably a `Copy`/cheaply-copied field; confirm against
            // `ChatCompletionResponse`'s definition.
            ChatCompletionResponse { id: self.response.id, message: None, tool_calls: None, usage: None, reasoning: None, delta: self.response.delta.clone(), }
        };
        ResponsesStreamItem { response, finished }
    }

    /// Emits the final item exactly once; returns `None` when a terminal
    /// event already finished the stream.
    fn take_final(&mut self, stream_full: bool) -> Option<ResponsesStreamItem> { if self.finished { None } else { self.finished = true; Some(self.emit(stream_full, true)) } }
}

/// Wraps a raw `ResponseStream` in the accumulating adapter.
pub(super) fn responses_stream_adapter( stream: ResponseStream, stream_full: bool, ) -> ResponsesStreamAdapter { ResponsesStreamAdapter::new(stream, stream_full) }

/// Stream adapter turning Responses API events into `ResponsesStreamItem`s
/// via `ResponsesStreamState`.
pub(super) struct ResponsesStreamAdapter { inner: ResponseStream, state: ResponsesStreamState, stream_full: bool, finished: bool, }

impl ResponsesStreamAdapter { fn new(stream: ResponseStream, stream_full: bool) -> Self { Self { inner: stream, state: ResponsesStreamState::default(), stream_full, finished: false, } } }

impl Stream for ResponsesStreamAdapter {
    type Item = LmResult<ResponsesStreamItem>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.finished { return Poll::Ready(None); }
        // Loop so events that produce no emission don't surface as `Pending`.
        loop {
            match this.inner.as_mut().poll_next(cx) {
                Poll::Ready(Some(result)) => {
                    let event = match result { Ok(event) => event, Err(err) => { this.finished = true; return Poll::Ready(Some(Err(openai_error_to_language_model_error( err, )))); } };
                    match this.state.apply_event(event, this.stream_full) { Ok(Some(item)) => { if item.finished { this.finished = true; } return Poll::Ready(Some(Ok(item))); } Ok(None) => {} Err(err) => { this.finished = true; return Poll::Ready(Some(Err(err))); } }
                }
                Poll::Ready(None) => {
                    // Upstream ended without a terminal event: flush whatever
                    // the state has accumulated as the final item.
                    this.finished = true;
                    if let Some(item) = this.state.take_final(this.stream_full) {
                        return Poll::Ready(Some(Ok(item)));
                    }
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

/// Converts a complete (non-streamed) `Response` into a
/// `ChatCompletionResponse`; a `Failed` status becomes a permanent error.
pub(super) fn response_to_chat_completion(response: &Response) -> LmResult<ChatCompletionResponse> {
    if matches!(response.status, Status::Failed) { let error = response.error.as_ref().map_or_else( || "OpenAI Responses API returned failure".to_string(), |err| format!("{}: {}", err.code, err.message), ); return Err(LanguageModelError::permanent(error)); }
    let mut builder = ChatCompletionResponse::builder();
    let reasoning_items = collect_reasoning_items_from_items(&response.output);
    if !reasoning_items.is_empty() { builder.reasoning(reasoning_items); }
    // Prefer the SDK's aggregated `output_text`; fall back to collecting the
    // message text manually from the output items.
    if let Some(text) = response.output_text().filter(|s| !s.is_empty()) { builder.message(text); } else if let Some(text) = collect_message_text_from_items(&response.output) { builder.message(text); }
    let tool_calls = collect_tool_calls_from_items(&response.output)?;
    if !tool_calls.is_empty() { builder.tool_calls(tool_calls); }
    if let Some(usage) = response.usage.as_ref() { builder.usage(Usage::from(usage)); }
    builder.build().map_err(LanguageModelError::from)
}

/// Merges trailing response metadata into an accumulated completion. Usage is
/// always taken from the metadata; message, tool calls and reasoning are only
/// filled in when the accumulator does not have them yet.
pub(super) fn metadata_to_chat_completion( metadata: &Response, accumulator: &mut ChatCompletionResponse, ) -> LmResult<()> {
    if let Some(usage) = metadata.usage.as_ref() { accumulator.usage = Some(Usage::from(usage)); }
    if accumulator.message.is_none() && let Some(text) = collect_message_text_from_items(&metadata.output) { accumulator.message = Some(text); }
    if accumulator.tool_calls.is_none() { let tool_calls = collect_tool_calls_from_items(&metadata.output)?; if !tool_calls.is_empty() { accumulator.tool_calls = Some(tool_calls); } }
    if accumulator.reasoning.is_none() { let reasoning_items = collect_reasoning_items_from_items(&metadata.output); if !reasoning_items.is_empty() { accumulator.reasoning = Some(reasoning_items); } }
    Ok(())
}

/// Newline-joins every `output_text` part across all message items;
/// `None` when there is no text at all.
fn collect_message_text_from_items(output: &[OutputItem]) -> Option<String> { let mut buffer = String::new(); for item
in output { if let OutputItem::Message(OutputMessage { content, .. }) = item { for part in content { if let OutputMessageContent::OutputText(text) = part { if !buffer.is_empty() { buffer.push('\n'); } buffer.push_str(&text.text); } } } } if buffer.is_empty() { None } else { Some(buffer) } }

/// Newline-joins the `output_text` parts of a single message; `None` when the
/// message contains no text parts.
fn collect_message_text_from_message(message: &OutputMessage) -> Option<String> { let mut buffer = String::new(); for part in &message.content { if let OutputMessageContent::OutputText(text) = part { if !buffer.is_empty() { buffer.push('\n'); } buffer.push_str(&text.text); } } if buffer.is_empty() { None } else { Some(buffer) } }

/// Converts every function-call output item into a `ToolCall`.
fn collect_tool_calls_from_items(output: &[OutputItem]) -> LmResult<Vec<ToolCall>> { let calls = output.iter().filter_map(|item| match item { OutputItem::FunctionCall(function_call) => Some(function_call), _ => None, }); tool_calls_from_iter(calls) }

/// Maps reasoning output items into Swiftide `ReasoningItem`s, carrying over
/// summary texts, optional content, status and encrypted content.
fn collect_reasoning_items_from_items(output: &[OutputItem]) -> Vec<ReasoningItem> { output .iter() .filter_map(|item| match item { OutputItem::Reasoning(reasoning) => Some(ReasoningItem { id: reasoning.id.clone(), summary: reasoning .summary .iter() .map(|part| match part { async_openai::types::responses::SummaryPart::SummaryText(summary) => { summary.text.clone() } }) .collect(), content: reasoning .content .as_ref() .map(|c| c.iter().map(|c| c.text.clone()).collect()), status: { if let Some(status) = &reasoning.status { match status { OutputStatus::Completed => { Some(swiftide_core::chat_completion::ReasoningStatus::Completed) } OutputStatus::InProgress => { Some(swiftide_core::chat_completion::ReasoningStatus::InProgress) } OutputStatus::Incomplete => { Some(swiftide_core::chat_completion::ReasoningStatus::Incomplete) } } } else { None } }, encrypted_content: reasoning.encrypted_content.clone(), }), _ => None, }) .collect() }

/// Builds a `ToolCall` from a function-call item, preferring `call_id` and
/// falling back to the optional `id` when `call_id` is empty.
fn tool_call_from_function_call(function_call: &FunctionToolCall) -> LmResult<ToolCall> { let id = if function_call.call_id.is_empty() { function_call.id.as_deref().unwrap_or_default().to_string() } else { function_call.call_id.clone() };
    let mut builder = ToolCall::builder(); builder.id(id); builder.name(function_call.name.clone()); if !function_call.arguments.is_empty() { builder.maybe_args(Some(function_call.arguments.clone())); }
    builder .build() .context("Failed to build tool call") .map_err(LanguageModelError::permanent)
}

/// Converts an iterator of function-call items, short-circuiting on the
/// first failure.
fn tool_calls_from_iter<'a, I>(calls: I) -> LmResult<Vec<ToolCall>> where I: IntoIterator<Item = &'a FunctionToolCall>, { calls .into_iter() .map(tool_call_from_function_call) .collect::<Result<Vec<_>, _>>() }

/// Borrowed variant of the id preference used by `tool_call_from_function_call`:
/// `call_id` when non-empty, otherwise the optional `id`.
fn function_call_identifier(function_call: &FunctionToolCall) -> &str { if function_call.call_id.is_empty() { function_call .id .as_deref() .unwrap_or(function_call.call_id.as_str()) } else { function_call.call_id.as_str() } }

/// Builds a plain-text Responses API request from a prompt string.
///
/// # Errors
/// Fails when no prompt model is configured or request construction fails.
pub(super) fn build_responses_request_from_prompt<C>( client: &GenericOpenAI<C>, prompt_text: String, ) -> LmResult<CreateResponse> where C: async_openai::config::Config + Clone + Default, {
    let model = client .options() .prompt_model .as_ref() .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
    let mut args = base_request_args(client, model)?;
    args.input(InputParam::Items(vec![InputItem::EasyMessage( EasyInputMessageArgs::default() .r#type(MessageType::Message) .role(Role::User) .content(EasyInputContent::Text(prompt_text)) .build() .map_err(LanguageModelError::permanent)?, )]));
    args.build().map_err(openai_error_to_language_model_error)
}

/// Same as `build_responses_request_from_prompt`, but constrains the output
/// to a strict JSON schema via the `text` response format.
pub(super) fn build_responses_request_from_prompt_with_schema<C>( client: &GenericOpenAI<C>, prompt_text: String, schema: serde_json::Value, ) -> LmResult<CreateResponse> where C: async_openai::config::Config + Clone + Default, {
    let model = client .options() .prompt_model .as_ref() .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
    let mut args = base_request_args(client, model)?;
    args.input(InputParam::Items(vec![InputItem::EasyMessage(
EasyInputMessageArgs::default() .r#type(MessageType::Message) .role(Role::User) .content(EasyInputContent::Text(prompt_text)) .build() .map_err(LanguageModelError::permanent)?, )])); args.text(ResponseTextParam { format: TextResponseFormatConfiguration::JsonSchema(ResponseFormatJsonSchema { description: None, name: "swiftide_structured_output".into(), schema: Some(schema), strict: Some(true), }), verbosity: None, }); args.build().map_err(openai_error_to_language_model_error) } #[allow(clippy::items_after_statements)] #[cfg(test)] mod tests { use super::*; use async_openai::types::responses::{ AssistantRole, FunctionToolCall, IncludeEnum, InputTokenDetails, OutputItem, OutputMessage, OutputMessageContent, OutputStatus, OutputTextContent, OutputTokenDetails, ReasoningEffort, ReasoningSummary, ResponseCompletedEvent, ResponseErrorEvent, ResponseFailedEvent, ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, ResponseStreamEvent, ResponseTextDeltaEvent, ResponseUsage as ResponsesUsage, Tool, }; use serde_json::{json, to_value}; use std::collections::HashSet; use swiftide_core::chat_completion::{ ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ChatMessageContentPart, ReasoningItem, ToolCall, ToolSpec, Usage, }; use crate::openai::{OpenAI, Options}; fn expect_emit( state: &mut ResponsesStreamState, event: ResponseStreamEvent, stream_full: bool, ) -> ResponsesStreamItem { state .apply_event(event, stream_full) .unwrap() .expect("expected emission") } fn expect_no_emit( state: &mut ResponsesStreamState, event: ResponseStreamEvent, stream_full: bool, ) { assert!( state.apply_event(event, stream_full).unwrap().is_none(), "expected no emission" ); } fn sample_usage() -> ResponsesUsage { ResponsesUsage { input_tokens: 5, input_tokens_details: InputTokenDetails { cached_tokens: 0 }, output_tokens: 3, output_tokens_details: OutputTokenDetails { reasoning_tokens: 0, }, 
total_tokens: 8, } }

    // Minimal schema fixture used by the tool-spec helpers below.
    #[allow(dead_code)] #[derive(schemars::JsonSchema)] struct WeatherArgs { _city: String, }

    // Fixture with a nested object, for the nested-required schema test.
    #[allow(dead_code)] #[derive(schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct NestedCommentArgs { request: NestedCommentRequest, }

    #[allow(dead_code)] #[derive(schemars::JsonSchema)] #[serde(deny_unknown_fields)] struct NestedCommentRequest { #[serde(default, skip_serializing_if = "Option::is_none")] body: Option<String>, #[serde(default, skip_serializing_if = "Option::is_none")] text: Option<String>, #[serde(default, skip_serializing_if = "Option::is_none")] page_id: Option<String>, #[serde(default, skip_serializing_if = "Option::is_none")] block_id: Option<String>, #[serde(default, skip_serializing_if = "Option::is_none")] discussion_id: Option<String>, }

    /// Tool spec fixture named `get_weather`.
    fn sample_tool_spec() -> ToolSpec { ToolSpec::builder() .name("get_weather") .description("Retrieve weather data") .parameters_schema(schemars::schema_for!(WeatherArgs)) .build() .unwrap() }

    /// Tool spec fixture with a caller-chosen name.
    fn sample_tool_spec_named(name: &str) -> ToolSpec { ToolSpec::builder() .name(name) .description(format!("{name} description")) .parameters_schema(schemars::schema_for!(WeatherArgs)) .build() .unwrap() }

    // Text + image parts serialize to input_text / input_image entries.
    #[test] fn test_user_parts_to_easy_input_content_with_image() { let parts = vec![ ChatMessageContentPart::text("Describe this image."), ChatMessageContentPart::image("https://example.com/image.png"), ]; let easy = user_parts_to_easy_input_content(&parts).expect("map user parts"); let value = to_value(easy).expect("serialize easy content"); let parts = value.as_array().expect("expected content list array"); assert_eq!(parts[0]["type"], "input_text"); assert_eq!(parts[0]["text"], "Describe this image."); assert_eq!(parts[1]["type"], "input_image"); assert_eq!(parts[1]["image_url"], "https://example.com/image.png"); assert_eq!(parts[1]["detail"], "auto"); }

    /// Builds an assistant `OutputMessage` with one `output_text` part per entry.
    fn output_message(id: &str, parts: &[&str]) -> OutputMessage { OutputMessage { content: parts .iter() .map(|text| { OutputMessageContent::OutputText(OutputTextContent { annotations: Vec::new(), logprobs: None, text: (*text).to_string(), }) }) .collect(), id: id.to_string(), role: AssistantRole::Assistant, status: OutputStatus::Completed, } }

    /// Response fixture containing a message, a function call and a reasoning item.
    fn response_with_message_tool_reasoning(message: &str) -> Response { let output_message = OutputItem::Message(output_message("msg", &[message])); let output = vec![ serde_json::to_value(output_message).expect("output message serializes"), json!({ "type": "function_call", "id": "call", "call_id": "call", "name": "metadata_tool", "arguments": "{\"ok\":true}", "status": "completed" }), json!({ "type": "reasoning", "id": "reasoning_meta", "summary": [ {"type": "summary_text", "text": "metadata summary"} ] }), ]; serde_json::from_value(json!({ "created_at": 0, "id": "resp", "model": "gpt-4.1", "object": "response", "status": "completed", "output": output, "usage": sample_usage(), })) .expect("valid response json") }

    // Chat request conversion carries model, options, sorted tools and tool choice.
    #[test] fn test_build_responses_request_includes_tools_and_options() { let openai = OpenAI::builder() .default_prompt_model("gpt-4.1") .parallel_tool_calls(Some(true)) .default_options( Options::builder() .metadata(json!({"tag": "demo"})) .user("tester") .temperature(0.2), ) .build() .unwrap(); let mut tools = HashSet::new(); tools.insert(sample_tool_spec_named("z_tool")); tools.insert(sample_tool_spec_named("a_tool")); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .tool_specs(tools) .build() .unwrap(); let create = build_responses_request_from_chat(&openai, &request).unwrap(); assert_eq!(create.model.as_deref(), Some("gpt-4.1")); assert_eq!(create.temperature, Some(0.2)); assert_eq!(create.parallel_tool_calls, Some(true)); assert_eq!( create .metadata .as_ref() .and_then(|m| m.get("tag")) .map(String::as_str), Some("demo"), ); let InputParam::Items(items) = &create.input else { panic!("expected items input"); }; assert_eq!(items.len(), 1); let tools = create.tools.expect("tools present"); assert_eq!(tools.len(), 2); let tool_names = tools .iter() .map(|tool| match tool { Tool::Function(function) => function.name.as_str(), _ => panic!("expected function tool"), }) .collect::<Vec<_>>(); assert_eq!(tool_names, vec!["a_tool", "z_tool"]); assert_eq!( create.tool_choice, Some(ToolChoiceParam::Mode(ToolChoiceOptions::Auto)) ); }

    // Tool parameter schemas must end up with additionalProperties: false.
    #[test] fn test_build_responses_request_sets_additional_properties_false_for_custom_tool_schema() { let openai = OpenAI::builder() .default_prompt_model("gpt-4.1") .build() .unwrap(); let mut tools = HashSet::new(); tools.insert(sample_tool_spec()); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .tool_specs(tools) .build() .unwrap(); let create = build_responses_request_from_chat(&openai, &request).unwrap(); let tools = create.tools.expect("tools present"); assert_eq!(tools.len(), 1); let Tool::Function(function) = &tools[0] else { panic!("expected function tool"); }; let additional_properties = function .parameters .as_ref() .and_then(|params| params.get("additionalProperties").cloned()); #[allow(dead_code)] #[derive(schemars::JsonSchema)] #[serde(deny_unknown_fields)] #[schemars(title = "WeatherArgs")] struct WeatherArgsCorrect { _city: String, } let expected_parameters = serde_json::json!({ "type": "object", "title": "WeatherArgs", "properties": { "_city": { "type": "string" } }, "required": ["_city"], "additionalProperties": false }); assert_eq!( additional_properties, Some(serde_json::Value::Bool(false)), "OpenAI requires additionalProperties to be set to false for tool parameters, got {}", serde_json::to_string_pretty(&function.parameters).unwrap() ); assert_eq!(function.parameters, Some(expected_parameters)); }

    // Nested request objects keep their full `required` list in the schema.
    #[test] fn test_build_responses_request_sets_nested_required_for_typed_request_objects() { let openai = OpenAI::builder() .default_prompt_model("gpt-4.1") .build() .unwrap(); let mut tools = HashSet::new(); tools.insert( ToolSpec::builder()
.name("notion_create_comment") .description("Create a comment") .parameters_schema(schemars::schema_for!(NestedCommentArgs)) .build() .unwrap(), ); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .tool_specs(tools) .build() .unwrap(); let create = build_responses_request_from_chat(&openai, &request).unwrap(); let tools = create.tools.expect("tools present"); let Tool::Function(function) = &tools[0] else { panic!("expected function tool"); }; let nested_required = function.parameters.as_ref().and_then(|schema| { let request_schema = schema .get("properties") .and_then(|value| value.get("request")) .and_then(serde_json::Value::as_object)?; let referenced_required = request_schema .get("$ref") .and_then(serde_json::Value::as_str) .and_then(|reference| reference.strip_prefix("#/$defs/")) .and_then(|definition_name| { schema .get("$defs") .and_then(|value| value.get(definition_name)) }) .and_then(|value| value.get("required")) .and_then(serde_json::Value::as_array); referenced_required.or_else(|| { request_schema .get("required") .and_then(serde_json::Value::as_array) }) }); let nested_required = nested_required.expect("nested request should have required"); let names: std::collections::HashSet<_> = nested_required .iter() .filter_map(serde_json::Value::as_str) .collect(); assert!(names.contains("body")); assert!(names.contains("text")); assert!(names.contains("page_id")); assert!(names.contains("block_id")); assert!(names.contains("discussion_id")); }

    // Reasoning requests are stateless (store=false), with auto summaries and
    // encrypted reasoning content included.
    #[test] fn test_build_responses_request_reasoning_is_stateless_with_summary_and_encrypted_content() { let openai = OpenAI::builder() .default_prompt_model("gpt-4.1") .default_options(Options::builder().reasoning_effort(ReasoningEffort::Low)) .build() .unwrap(); let request = ChatCompletionRequest::builder() .messages(vec![ChatMessage::User("hi".into())]) .build() .unwrap(); let create = build_responses_request_from_chat(&openai, &request).unwrap(); assert_eq!(create.store, Some(false)); assert_eq!( create.reasoning.as_ref().and_then(|r| r.summary), Some(ReasoningSummary::Auto) ); assert!( create .include .as_ref() .is_some_and(|items| items.contains(&IncludeEnum::ReasoningEncryptedContent)) ); }

    // Assistant tool calls without text still become function-call items,
    // with ids normalized to the fc_ prefix.
    #[test] fn test_chat_messages_to_input_items_keeps_tool_calls_without_content() { let tool_call = ToolCall::builder() .id("call_123") .name("lookup") .maybe_args(Some("{\"q\":\"rust\"}".to_string())) .build() .unwrap(); let message = ChatMessage::Assistant(None, Some(vec![tool_call])); let items = chat_messages_to_input_items(&[message], true).expect("conversion succeeds"); assert_eq!(items.len(), 1); let InputItem::Item(async_openai::types::responses::Item::FunctionCall(function_call)) = &items[0] else { panic!("expected function call item"); }; assert_eq!(function_call.call_id, "fc_123"); assert_eq!(function_call.name, "lookup"); assert_eq!(function_call.arguments, "{\"q\":\"rust\"}"); assert_eq!(function_call.status, Some(OutputStatus::InProgress)); }

    // Reasoning messages round-trip id and encrypted content (summary dropped).
    #[test] fn test_chat_messages_to_input_items_includes_reasoning_with_encrypted_content() { let message = ChatMessage::Reasoning(ReasoningItem { id: "reasoning_1".to_string(), summary: vec!["First".to_string(), "Second".to_string()], encrypted_content: Some("encrypted".to_string()), ..Default::default() }); let items = chat_messages_to_input_items(&[message], true).expect("conversion succeeds"); assert_eq!(items.len(), 1); let InputItem::Item(async_openai::types::responses::Item::Reasoning(reasoning_item)) = &items[0] else { panic!("expected reasoning item"); }; assert_eq!(reasoning_item.id, "reasoning_1"); assert!(reasoning_item.summary.is_empty()); assert_eq!( reasoning_item.encrypted_content.as_deref(), Some("encrypted") ); }

    // Assistant messages with neither text nor tool calls are dropped.
    #[test] fn test_chat_messages_to_input_items_ignores_empty_assistant() { let message = ChatMessage::Assistant(None, None); let items = chat_messages_to_input_items(&[message], true).expect("conversion succeeds"); assert!(items.is_empty()); }

    // Falls back to the item `id` when `call_id` is empty.
    #[test] fn test_tool_call_from_function_call_uses_id_when_call_id_missing() { let function_call = FunctionToolCall { arguments: String::new(), call_id: String::new(), name: "lookup".to_string(), id: Some("call_456".to_string()), status: Some(OutputStatus::Completed), }; let tool_call = tool_call_from_function_call(&function_call).expect("tool call"); assert_eq!(tool_call.id(), "call_456"); assert_eq!(tool_call.name(), "lookup"); assert!(tool_call.args().is_none()); }

    // Text collection joins parts across items and within a single message.
    #[test] fn test_collect_message_text_helpers_join_parts() { let output = vec![ OutputItem::Message(output_message("msg_1", &["First", "Second"])), OutputItem::FunctionCall(FunctionToolCall { arguments: "{}".to_string(), call_id: "call".to_string(), name: "noop".to_string(), id: None, status: Some(OutputStatus::Completed), }), OutputItem::Message(output_message("msg_2", &["Third"])), ]; let collected = collect_message_text_from_items(&output).expect("text present"); assert_eq!(collected, "First\nSecond\nThird"); let message = output_message("msg_single", &["Line one", "Line two"]); let collected_message = collect_message_text_from_message(&message).expect("message text present"); assert_eq!(collected_message, "Line one\nLine two"); }

    // Metadata fills empty fields but never clobbers populated ones
    // (usage is always replaced).
    #[test] fn test_metadata_to_chat_completion_respects_existing_fields() { let metadata = response_with_message_tool_reasoning("metadata message"); let mut empty = ChatCompletionResponse::default(); metadata_to_chat_completion(&metadata, &mut empty).expect("metadata applies"); assert_eq!(empty.message.as_deref(), Some("metadata message")); assert!(empty.tool_calls.is_some()); assert!(empty.reasoning.is_some()); assert!(empty.usage.is_some()); let existing_tool = ToolCall::builder() .id("existing") .name("existing_tool") .maybe_args(Some("{\"keep\":true}".to_string())) .build() .unwrap(); let existing_reasoning = ReasoningItem { id: "existing_reasoning".to_string(), summary: vec!["keep".to_string()], encrypted_content: None, ..Default::default() }; let existing_usage = Usage {
prompt_tokens: 1, completion_tokens: 1, total_tokens: 2, details: None, }; let mut existing = ChatCompletionResponse::builder() .message("existing message") .tool_calls(vec![existing_tool.clone()]) .reasoning(vec![existing_reasoning.clone()]) .usage(existing_usage) .build() .unwrap(); metadata_to_chat_completion(&metadata, &mut existing).expect("metadata applies"); assert_eq!(existing.message.as_deref(), Some("existing message")); assert_eq!( existing .tool_calls .as_ref() .and_then(|calls| calls.first()) .map(ToolCall::id), Some("existing") ); assert_eq!( existing .reasoning .as_ref() .and_then(|items| items.first()) .map(|item| item.id.as_str()), Some("existing_reasoning") ); assert_eq!( existing.usage.as_ref().map(|usage| usage.total_tokens), Some(sample_usage().total_tokens) ); }

    // Structured tool outputs are serialized as compact JSON text.
    #[test] fn test_tool_output_preserves_structured_values() { let tool_call = ToolCall::builder() .id("fc_test") .name("demo") .maybe_args(Some("{\"ok\":true}".to_owned())) .build() .unwrap(); let messages = vec![ ChatMessage::ToolOutput( tool_call.clone(), ToolOutput::Stop(Some(json!({"foo": "bar"}))), ), ChatMessage::ToolOutput( tool_call.clone(), ToolOutput::FeedbackRequired(Some(json!({"nested": {"a": 1}}))), ), ChatMessage::ToolOutput( tool_call.clone(), ToolOutput::AgentFailed(Some(json!([1, 2, 3]))), ), ]; let items = chat_messages_to_input_items(&messages, true).expect("conversion succeeds"); assert_eq!(items.len(), 3); for (item, expected) in items .iter() .zip([r#"{"foo":"bar"}"#, r#"{"nested":{"a":1}}"#, r"[1,2,3]"]) { let InputItem::Item(async_openai::types::responses::Item::FunctionCallOutput( function_output, )) = item else { panic!("expected function call output item"); }; assert_eq!(function_output.call_id, "fc_test"); assert_eq!( function_output.output, FunctionCallOutput::Text(expected.to_string()) ); } }

    // Full response mapping: message text, tool calls and usage.
    #[test] fn test_response_to_chat_completion_maps_outputs() { let usage = sample_usage(); let response: Response = serde_json::from_value(json!({ "created_at": 0, "id": "resp", "model": "gpt-4.1", "object": "response", "status": "completed", "output": [ { "type": "message", "id": "msg", "role": "assistant", "status": "completed", "content": [ {"type": "output_text", "text": "Assistant reply", "annotations": []} ] }, { "type": "function_call", "id": "tool", "call_id": "tool", "name": "get_weather", "arguments": "{\"city\":\"Oslo\"}", "status": "completed" } ], "usage": usage, })) .expect("valid response json"); let completion = response_to_chat_completion(&response).unwrap(); assert_eq!(completion.message(), Some("Assistant reply")); let tool_calls = completion.tool_calls().expect("tool calls present"); assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].name(), "get_weather"); assert_eq!(tool_calls[0].args(), Some("{\"city\":\"Oslo\"}")); let usage = completion.usage.expect("usage"); assert_eq!(usage.prompt_tokens, 5); assert_eq!(usage.completion_tokens, 3); assert_eq!(usage.total_tokens, 8); }

    // Reasoning items carry summaries and encrypted content through mapping.
    #[test] fn test_response_to_chat_completion_collects_reasoning_summary_and_encrypted_content() { let usage = sample_usage(); let response: Response = serde_json::from_value(json!({ "created_at": 0, "id": "resp", "model": "gpt-4.1", "object": "response", "status": "completed", "output": [ { "type": "reasoning", "id": "reasoning_1", "summary": [ {"type": "summary_text", "text": "First"}, {"type": "summary_text", "text": "Second"} ], "encrypted_content": "encrypted" } ], "usage": usage, })) .expect("valid response json"); let completion = response_to_chat_completion(&response).unwrap(); let reasoning = completion.reasoning.expect("reasoning items present"); assert_eq!(reasoning.len(), 1); assert_eq!(reasoning[0].id, "reasoning_1"); assert_eq!( reasoning[0].summary, vec!["First".to_string(), "Second".to_string()] ); assert_eq!(reasoning[0].encrypted_content.as_deref(), Some("encrypted")); }

    // End-to-end stream accumulation: text delta, tool call, args, completion.
    #[test] fn test_stream_accumulator_handles_text_and_tool_events() { let mut state = ResponsesStreamState::default(); let delta: ResponseTextDeltaEvent = serde_json::from_value(json!({ "sequence_number": 0, "item_id": "msg_1", "output_index": 0, "content_index": 0, "delta": "Hello" })) .unwrap(); let chunk = expect_emit( &mut state, ResponseStreamEvent::ResponseOutputTextDelta(delta), false, ); assert_eq!( chunk .response .delta .as_ref() .and_then(|d| d.message_chunk.as_deref()), Some("Hello") ); let item_added: ResponseOutputItemAddedEvent = serde_json::from_value(json!({ "sequence_number": 1, "output_index": 0, "item": { "type": "function_call", "id": "call", "call_id": "call", "name": "lookup", "arguments": "", "status": "in_progress" } })) .unwrap(); expect_emit( &mut state, ResponseStreamEvent::ResponseOutputItemAdded(item_added), false, ); let args_delta: ResponseFunctionCallArgumentsDeltaEvent = serde_json::from_value(json!({ "sequence_number": 2, "item_id": "call", "output_index": 0, "delta": "{\"q\":\"rust\"}" })) .unwrap(); expect_emit( &mut state, ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(args_delta), false, ); let args_done: ResponseFunctionCallArgumentsDoneEvent = serde_json::from_value(json!({ "sequence_number": 3, "item_id": "call", "output_index": 0, "name": "lookup", "arguments": "{\"q\":\"rust\"}" })) .unwrap(); expect_emit( &mut state, ResponseStreamEvent::ResponseFunctionCallArgumentsDone(args_done), false, ); let usage = sample_usage(); let completed: ResponseCompletedEvent = serde_json::from_value(json!({ "sequence_number": 4, "response": { "id": "resp", "object": "response", "created_at": 0, "status": "completed", "model": "gpt-4.1", "output": [], "usage": to_value(&usage).unwrap() } })) .unwrap(); let final_chunk = expect_emit( &mut state, ResponseStreamEvent::ResponseCompleted(completed), false, ); assert!(final_chunk.finished); assert_eq!(final_chunk.response.message(), Some("Hello")); let tool_calls = final_chunk .response .tool_calls() .expect("tool calls present"); assert_eq!(tool_calls[0].name(), "lookup"); assert_eq!(tool_calls[0].args(), Some("{\"q\":\"rust\"}")); let
usage = final_chunk.response.usage.expect("usage"); assert_eq!(usage.total_tokens, 8); }

    // take_final emits exactly once.
    #[test] fn test_stream_state_take_final_only_once() { let mut state = ResponsesStreamState::default(); assert!(state.take_final(true).is_some()); assert!(state.take_final(true).is_none()); }

    // Events arriving after a terminal completion are ignored.
    #[test] fn test_stream_state_ignores_events_after_completion() { let mut state = ResponsesStreamState::default(); let usage = sample_usage(); let completed: ResponseCompletedEvent = serde_json::from_value(json!({ "sequence_number": 0, "response": { "id": "resp", "object": "response", "created_at": 0, "status": "completed", "model": "gpt-4.1", "output": [], "usage": to_value(&usage).unwrap() } })) .unwrap(); let finished = expect_emit( &mut state, ResponseStreamEvent::ResponseCompleted(completed), false, ); assert!(finished.finished); let delta: ResponseTextDeltaEvent = serde_json::from_value(json!({ "sequence_number": 1, "item_id": "msg_1", "output_index": 0, "content_index": 0, "delta": "ignored" })) .unwrap(); expect_no_emit( &mut state, ResponseStreamEvent::ResponseOutputTextDelta(delta), false, ); }

    // A whole message item added at once contributes its joined text.
    #[test] fn test_stream_state_message_item_added_collects_text() { let mut state = ResponsesStreamState::default(); let item_added: ResponseOutputItemAddedEvent = serde_json::from_value(json!({ "sequence_number": 0, "output_index": 0, "item": { "type": "message", "id": "msg", "role": "assistant", "status": "completed", "content": [ {"type": "output_text", "text": "Hello", "annotations": []}, {"type": "output_text", "text": "World", "annotations": []} ] } })) .unwrap(); let chunk = expect_emit( &mut state, ResponseStreamEvent::ResponseOutputItemAdded(item_added), true, ); assert_eq!(chunk.response.message(), Some("Hello\nWorld")); }

    // Item-done keeps the id seeded by item-added (empty call_id at done time).
    #[test] fn test_stream_state_output_item_done_emits_tool_call() { let mut state = ResponsesStreamState::default(); let item_added: ResponseOutputItemAddedEvent = serde_json::from_value(json!({ "sequence_number": 0, "output_index": 0, "item": { "type": "function_call", "id": "call", "call_id": "call", "name": "lookup", "arguments": "", "status": "in_progress" } })) .unwrap(); expect_emit( &mut state, ResponseStreamEvent::ResponseOutputItemAdded(item_added), true, ); let done: ResponseOutputItemDoneEvent = serde_json::from_value(json!({ "sequence_number": 1, "output_index": 0, "item": { "type": "function_call", "id": "call-id", "call_id": "", "name": "lookup", "arguments": "", "status": "completed" } })) .unwrap(); let chunk = expect_emit( &mut state, ResponseStreamEvent::ResponseOutputItemDone(done), true, ); let calls = chunk.response.tool_calls().expect("tool calls present"); assert_eq!(calls[0].id(), "call"); assert_eq!(calls[0].name(), "lookup"); }

    // Arguments-done duplicating already-streamed args emits nothing.
    #[test] fn test_stream_state_duplicate_arguments_done_no_emit() { let mut state = ResponsesStreamState::default(); let item_added: ResponseOutputItemAddedEvent = serde_json::from_value(json!({ "sequence_number": 0, "output_index": 0, "item": { "type": "function_call", "id": "call", "call_id": "call", "name": "lookup", "arguments": "", "status": "in_progress" } })) .unwrap(); expect_emit( &mut state, ResponseStreamEvent::ResponseOutputItemAdded(item_added), false, ); let args_delta: ResponseFunctionCallArgumentsDeltaEvent = serde_json::from_value(json!({ "sequence_number": 1, "item_id": "call", "output_index": 0, "delta": "{\"q\":1}" })) .unwrap(); expect_emit( &mut state, ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(args_delta), false, ); let args_done: ResponseFunctionCallArgumentsDoneEvent = serde_json::from_value(json!({ "sequence_number": 2, "item_id": "call", "output_index": 0, "arguments": "{\"q\":1}", "name": "" })) .unwrap(); expect_no_emit( &mut state, ResponseStreamEvent::ResponseFunctionCallArgumentsDone(args_done), false, ); }

    // Failed / error stream events surface as permanent errors.
    #[test] fn test_stream_state_response_failed_and_error() { let mut state = ResponsesStreamState::default(); let failed: ResponseFailedEvent = serde_json::from_value(json!({ "sequence_number": 0, "response": { "id": "resp", "object": "response", "created_at": 0, "status": "failed", "model": "gpt-4.1", "output": [], "error": {"code": "oops", "message": "boom"} } })) .unwrap(); let err = state .apply_event(ResponseStreamEvent::ResponseFailed(failed), false) .unwrap_err(); assert!( matches!(err, LanguageModelError::PermanentError(msg) if msg.to_string().contains("oops")) ); let mut state = ResponsesStreamState::default(); let err_event: ResponseErrorEvent = serde_json::from_value(json!({ "sequence_number": 1, "message": "bad things" })) .unwrap(); let err = state .apply_event(ResponseStreamEvent::ResponseError(err_event), false) .unwrap_err(); assert!( matches!(err, LanguageModelError::PermanentError(msg) if msg.to_string().contains("bad things")) ); }

    // Non-streamed failed responses error out as well.
    #[test] fn test_response_to_chat_completion_failed_status_errors() { let response: Response = serde_json::from_value(json!({ "created_at": 0, "id": "resp", "model": "gpt-4.1", "object": "response", "status": "failed", "error": {"code": "oops", "message": "boom"}, "output": [] })) .unwrap(); let err = response_to_chat_completion(&response).unwrap_err(); assert!( matches!(err, LanguageModelError::PermanentError(msg) if msg.to_string().contains("oops")) ); }

    // Metadata values must be strings to convert.
    #[test] fn test_convert_metadata_rejects_non_string_values() { let metadata = json!({"tag": 123}); assert!(convert_metadata(&metadata).is_none()); }

    // base_request_args accepts seed / presence_penalty / temperature options.
    #[test] fn test_base_request_args_runs_with_seed_and_presence_penalty() { let openai = OpenAI::builder() .default_prompt_model("gpt-4.1") .default_options( Options::builder() .seed(7) .presence_penalty(0.4) .temperature(0.1), ) .build() .unwrap(); assert!(base_request_args(&openai, "gpt-4.1").is_ok()); }

    // Id normalization: call_* -> fc_*; fc_* and other ids are unchanged.
    #[test] fn test_normalize_responses_function_call_id() { assert_eq!( normalize_responses_function_call_id("call_12345"), "fc_12345" ); assert_eq!(normalize_responses_function_call_id("fc_abc"), "fc_abc"); assert_eq!(normalize_responses_function_call_id("custom"), "custom"); }
}

================================================ FILE:
swiftide-integrations/src/openai/simple_prompt.rs ================================================ //! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct. //! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt //! processing and generating responses as part of the Swiftide system. use async_openai::types::chat::ChatCompletionRequestUserMessageArgs; use async_trait::async_trait; use swiftide_core::{ SimplePrompt, chat_completion::{Usage, errors::LanguageModelError}, prompt::Prompt, util::debug_long_utf8, }; use super::responses_api::{build_responses_request_from_prompt, response_to_chat_completion}; use crate::openai::openai_error_to_language_model_error; use super::GenericOpenAI; use anyhow::Result; /// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a /// response. #[async_trait] impl< C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone + 'static, > SimplePrompt for GenericOpenAI<C> { /// Sends a prompt to the `OpenAI` API and returns the response content. /// /// # Parameters /// - `prompt`: A string slice that holds the prompt to be sent to the `OpenAI` API. /// /// # Returns /// - `Result<String>`: On success, returns the content of the response as a `String`. On /// failure, returns an error wrapped in a `Result`. /// /// # Errors /// - Returns an error if the model is not set in the default options. /// - Returns an error if the request to the `OpenAI` API fails. /// - Returns an error if the response does not contain the expected content. 
#[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))] #[cfg_attr( feature = "langfuse", tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION")) )] async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> { if self.is_responses_api_enabled() { return self.prompt_via_responses_api(prompt).await; } // Retrieve the model from the default options, returning an error if not set. let model = self .default_options .prompt_model .as_ref() .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?; // Build the request to be sent to the OpenAI API. let request = self .chat_completion_request_defaults() .model(model) .messages(vec![ ChatCompletionRequestUserMessageArgs::default() .content(prompt.render()?) .build() .map_err(LanguageModelError::permanent)? .into(), ]) .build() .map_err(LanguageModelError::permanent)?; // Log the request for debugging purposes. tracing::trace!( model = &model, messages = debug_long_utf8( serde_json::to_string_pretty(&request.messages.last()) .map_err(LanguageModelError::permanent)?, 100 ), "[SimplePrompt] Request to openai" ); // Send the request to the OpenAI API and await the response. // Move the request; we logged key fields above if needed. 
let tracking_request = request.clone(); let response = self .client .chat() .create(request) .await .map_err(openai_error_to_language_model_error)?; let message = response .choices .first() .and_then(|choice| choice.message.content.clone()) .ok_or_else(|| { LanguageModelError::PermanentError("Expected content in response".into()) })?; let usage = response.usage.as_ref().map(Usage::from); self.track_completion( model, usage.as_ref(), Some(&tracking_request), Some(&response), ); Ok(message) } } impl< C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone + 'static, > GenericOpenAI<C> { async fn prompt_via_responses_api(&self, prompt: Prompt) -> Result<String, LanguageModelError> { let prompt_text = prompt.render().map_err(LanguageModelError::permanent)?; let model = self .default_options .prompt_model .as_ref() .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?; let create_request = build_responses_request_from_prompt(self, prompt_text.clone())?; let response = self .client .responses() .create(create_request.clone()) .await .map_err(openai_error_to_language_model_error)?; let completion = response_to_chat_completion(&response)?; let message = completion.message.clone().ok_or_else(|| { LanguageModelError::PermanentError("Expected content in response".into()) })?; self.track_completion( model, completion.usage.as_ref(), Some(&create_request), Some(&completion), ); Ok(message) } } #[allow(clippy::items_after_statements)] #[cfg(test)] mod tests { use super::*; use crate::openai::OpenAI; use serde_json::Value; use wiremock::{ Mock, MockServer, Request, Respond, ResponseTemplate, matchers::{method, path}, }; #[test_log::test(tokio::test)] async fn test_prompt_errors_when_model_missing() { let openai = OpenAI::builder().build().unwrap(); let result = openai.prompt("hello".into()).await; assert!(matches!(result, Err(LanguageModelError::PermanentError(_)))); } #[test_log::test(tokio::test)] async fn 
test_prompt_via_responses_api_returns_message() {
        let mock_server = MockServer::start().await;
        // Canned Responses API payload with a single assistant message.
        let response_body = serde_json::json!({
            "created_at": 0,
            "id": "resp",
            "model": "gpt-4.1-mini",
            "object": "response",
            "status": "completed",
            "output": [
                {
                    "type": "message",
                    "id": "msg",
                    "role": "assistant",
                    "status": "completed",
                    "content": [
                        {"type": "output_text", "text": "Hello world", "annotations": []}
                    ]
                }
            ],
            "usage": {
                "input_tokens": 4,
                "input_tokens_details": {"cached_tokens": 0},
                "output_tokens": 2,
                "output_tokens_details": {"reasoning_tokens": 0},
                "total_tokens": 6
            }
        });

        // Responder that also validates the shape of the outgoing request
        // before returning the canned body.
        struct ValidatePromptRequest {
            response: Value,
        }

        impl Respond for ValidatePromptRequest {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let payload: Value = serde_json::from_slice(&request.body).unwrap();
                // The request must carry the configured model and exactly one
                // message-type input item.
                assert_eq!(payload["model"], self.response["model"]);
                let items = payload["input"].as_array().expect("array input");
                assert_eq!(items.len(), 1);
                assert_eq!(items[0]["type"], "message");
                ResponseTemplate::new(200).set_body_json(self.response.clone())
            }
        }

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ValidatePromptRequest {
                response: response_body,
            })
            .mount(&mock_server)
            .await;

        // Point the client at the mock server with the Responses API enabled.
        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = async_openai::Client::with_config(config);
        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let result = openai.prompt("Say hi".into()).await.unwrap();
        assert_eq!(result, "Hello world");
    }

    #[test_log::test(tokio::test)]
    async fn test_prompt_via_responses_api_missing_output_errors() {
        let mock_server = MockServer::start().await;
        // A completed response with an empty `output` array has no message content.
        let empty_response = serde_json::json!({
            "created_at": 0,
            "id": "resp",
            "model": "gpt-4.1-mini",
            "object": "response",
            "output": [],
            "status": "completed"
        });

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(empty_response))
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = async_openai::Client::with_config(config);
        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        // Missing content surfaces as a permanent error.
        let err = openai.prompt("test".into()).await.unwrap_err();
        assert!(matches!(err, LanguageModelError::PermanentError(_)));
    }
}
================================================ FILE: swiftide-integrations/src/openai/structured_prompt.rs ================================================
//! This module provides an implementation of the `StructuredPrompt` trait for the `OpenAI` struct.
//!
//! Unlike the other traits, `StructuredPrompt` is *not* dyn safe.
//!
//! Use `DynStructuredPrompt` if you need dyn dispatch. For custom implementations, if you
//! implement `DynStructuredPrompt`, you get `StructuredPrompt` for free.
use async_openai::types::{
    chat::ChatCompletionRequestUserMessageArgs,
    responses::{ResponseFormat, ResponseFormatJsonSchema},
};
use async_trait::async_trait;
use schemars::Schema;
use swiftide_core::{
    DynStructuredPrompt,
    chat_completion::{Usage, errors::LanguageModelError},
    prompt::Prompt,
    util::debug_long_utf8,
};

use super::responses_api::{
    build_responses_request_from_prompt_with_schema, response_to_chat_completion,
};
use crate::openai::openai_error_to_language_model_error;

use super::GenericOpenAI;
use anyhow::{Context as _, Result};

/// The `StructuredPrompt` trait defines a method for sending a prompt to an AI model and receiving
/// a response.
#[async_trait]
impl<
    C: async_openai::config::Config
        + std::default::Default
        + Sync
        + Send
        + std::fmt::Debug
        + Clone
        + 'static,
> DynStructuredPrompt for GenericOpenAI<C>
{
    /// Sends a prompt to the `OpenAI` API and returns the response content.
    ///
    /// # Parameters
    /// - `prompt`: A string slice that holds the prompt to be sent to the `OpenAI` API.
/// /// # Returns /// - `Result<String>`: On success, returns the content of the response as a `String`. On /// failure, returns an error wrapped in a `Result`. /// /// # Errors /// - Returns an error if the model is not set in the default options. /// - Returns an error if the request to the `OpenAI` API fails. /// - Returns an error if the response does not contain the expected content. #[tracing::instrument(skip_all, err)] #[cfg_attr( feature = "langfuse", tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION")) )] async fn structured_prompt_dyn( &self, prompt: Prompt, schema: Schema, ) -> Result<serde_json::Value, LanguageModelError> { if self.is_responses_api_enabled() { return self .structured_prompt_via_responses_api(prompt, schema) .await; } // Retrieve the model from the default options, returning an error if not set. let model = self .default_options .prompt_model .as_ref() .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?; let schema_value = serde_json::to_value(&schema).context("Failed to get schema as value")?; let response_format = ResponseFormat::JsonSchema { json_schema: ResponseFormatJsonSchema { description: None, name: "structured_prompt".into(), schema: Some(schema_value), strict: Some(true), }, }; // Build the request to be sent to the OpenAI API. let request = self .chat_completion_request_defaults() .model(model) .response_format(response_format) .messages(vec![ ChatCompletionRequestUserMessageArgs::default() .content(prompt.render()?) .build() .map_err(LanguageModelError::permanent)? .into(), ]) .build() .map_err(LanguageModelError::permanent)?; // Log the request for debugging purposes. tracing::trace!( model = &model, messages = debug_long_utf8( serde_json::to_string_pretty(&request.messages.last()) .map_err(LanguageModelError::permanent)?, 100 ), "[StructuredPrompt] Request to openai" ); // Send the request to the OpenAI API and await the response. 
let response = self .client .chat() .create(request.clone()) .await .map_err(openai_error_to_language_model_error)?; let message = response .choices .first() .and_then(|choice| choice.message.content.clone()) .ok_or_else(|| { LanguageModelError::PermanentError("Expected content in response".into()) })?; let usage = response.usage.as_ref().map(Usage::from); self.track_completion(model, usage.as_ref(), Some(&request), Some(&response)); let parsed = serde_json::from_str(&message) .with_context(|| format!("Failed to parse response\n {message}"))?; // Extract and return the content of the response, returning an error if not found. Ok(parsed) } } impl< C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone + 'static, > GenericOpenAI<C> { async fn structured_prompt_via_responses_api( &self, prompt: Prompt, schema: Schema, ) -> Result<serde_json::Value, LanguageModelError> { let prompt_text = prompt.render().map_err(LanguageModelError::permanent)?; let model = self .default_options .prompt_model .as_ref() .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?; let schema_value = serde_json::to_value(&schema) .context("Failed to get schema as value") .map_err(LanguageModelError::permanent)?; let create_request = build_responses_request_from_prompt_with_schema( self, prompt_text.clone(), schema_value, )?; let tracking_request = create_request.clone(); let response = self .client .responses() .create(create_request) .await .map_err(openai_error_to_language_model_error)?; let completion = response_to_chat_completion(&response)?; let message = completion.message.clone().ok_or_else(|| { LanguageModelError::PermanentError("Expected content in response".into()) })?; self.track_completion( model, completion.usage.as_ref(), Some(&tracking_request), Some(&completion), ); let parsed = serde_json::from_str(&message) .with_context(|| format!("Failed to parse response\n {message}")) .map_err(LanguageModelError::permanent)?; 
Ok(parsed)
    }
}

#[cfg(test)]
mod tests {
    use crate::openai::{self, OpenAI};
    use swiftide_core::StructuredPrompt;

    use super::*;
    use async_openai::Client;
    use async_openai::config::OpenAIConfig;
    use schemars::{JsonSchema, schema_for};
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    // Minimal structured-output target used by all tests in this module.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
    struct SimpleOutput {
        answer: String,
    }

    /// Spins up a wiremock server answering `/chat/completions` with a canned
    /// `SimpleOutput` payload, and returns it with a configured `OpenAI` client.
    async fn setup_client() -> (MockServer, OpenAI) {
        // Start the Wiremock server
        let mock_server = MockServer::start().await;

        // Prepare the response the mock should return
        let assistant_msg = serde_json::json!({
            "role": "assistant",
            "content": serde_json::to_string(&SimpleOutput { answer: "42".to_owned() }).unwrap(),
        });

        let body = serde_json::json!({
            "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
            "object": "chat.completion",
            "created": 123,
            "model": "gpt-4.1-2025-04-14",
            "choices": [
                {
                    "index": 0,
                    "message": assistant_msg,
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 19,
                "completion_tokens": 10,
                "total_tokens": 29,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "audio_tokens": 0
                },
                "completion_tokens_details": {
                    "reasoning_tokens": 0,
                    "audio_tokens": 0,
                    "accepted_prediction_tokens": 0,
                    "rejected_prediction_tokens": 0
                }
            },
            "service_tier": "default"
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&mock_server)
            .await;

        // Point our client at the mock server
        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        // Construct the GenericOpenAI instance
        let opts = openai::Options {
            prompt_model: Some("gpt-4".to_string()),
            ..openai::Options::default()
        };
        (
            mock_server,
            OpenAI::builder()
                .client(client)
                .default_options(opts)
                .build()
                .unwrap(),
        )
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock() {
        // `_guard` keeps the mock server alive for the test's duration.
        let (_guard, ai) = setup_client().await;

        // Call structured_prompt
        let result: serde_json::Value = ai.structured_prompt("test".into()).await.unwrap();

        dbg!(&result);

        // Assert
        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput { answer: "42".into() }
        );
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock_as_box() {
        let (_guard, ai) = setup_client().await;

        // Call structured_prompt through the dyn-safe trait object.
        let ai: Box<dyn DynStructuredPrompt> = Box::new(ai);
        let result: serde_json::Value = ai
            .structured_prompt_dyn("test".into(), schema_for!(SimpleOutput))
            .await
            .unwrap();

        dbg!(&result);

        // Assert
        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput { answer: "42".into() }
        );
    }

    #[test_log::test(tokio::test)]
    async fn test_structured_prompt_via_responses_api() {
        let mock_server = MockServer::start().await;
        // Responses API payload whose output text is itself serialized JSON.
        let response_body = json!({
            "created_at": 0,
            "id": "resp",
            "model": "gpt-4.1-mini",
            "object": "response",
            "status": "completed",
            "output": [
                {
                    "type": "message",
                    "id": "msg",
                    "role": "assistant",
                    "status": "completed",
                    "content": [
                        {"type": "output_text", "text": serde_json::to_string(&SimpleOutput { answer: "structured".into() }).unwrap(), "annotations": []}
                    ]
                }
            ],
            "usage": {
                "input_tokens": 10,
                "input_tokens_details": {"cached_tokens": 0},
                "output_tokens": 4,
                "output_tokens_details": {"reasoning_tokens": 0},
                "total_tokens": 14
            }
        });

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);
        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let schema = schema_for!(SimpleOutput);
        let result = openai
            .structured_prompt_dyn("Render".into(), schema)
            .await
            .unwrap();

        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
answer: "structured".into(),
            }
        );
    }

    #[test_log::test(tokio::test)]
    async fn test_structured_prompt_via_responses_api_invalid_json_errors() {
        let mock_server = MockServer::start().await;
        // The model "replies" with text that is not valid JSON; parsing must fail.
        let bad_response = json!({
            "created_at": 0,
            "id": "resp",
            "model": "gpt-4.1-mini",
            "object": "response",
            "status": "completed",
            "output": [
                {
                    "type": "message",
                    "id": "msg",
                    "role": "assistant",
                    "status": "completed",
                    "content": [
                        {"type": "output_text", "text": "not json", "annotations": []}
                    ]
                }
            ]
        });

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(bad_response))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);
        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let schema = schema_for!(SimpleOutput);
        let err = openai
            .structured_prompt_dyn("Render".into(), schema)
            .await
            .unwrap_err();

        assert!(matches!(err, LanguageModelError::PermanentError(_)));
    }
}
================================================ FILE: swiftide-integrations/src/openai/tool_schema.rs ================================================
use serde_json::{Map, Value};
use swiftide_core::chat_completion::{ToolSpec, ToolSpecError};
use thiserror::Error;

// Function-pointer aliases so the pipeline can hold its steps in plain arrays.
type SchemaNormalizer = fn(&mut Value) -> Result<(), OpenAiToolSchemaError>;
type SchemaValidator = fn(&Value) -> Result<(), OpenAiToolSchemaError>;

/// A tool parameters schema normalized and validated for OpenAI strict tool calling.
#[derive(Debug)]
pub(super) struct OpenAiToolSchema(Value);

impl OpenAiToolSchema {
    /// Consumes the wrapper and yields the underlying JSON schema value.
    pub(super) fn into_value(self) -> Value {
        self.0
    }
}

impl TryFrom<&ToolSpec> for OpenAiToolSchema {
    type Error = OpenAiToolSchemaError;

    fn try_from(spec: &ToolSpec) -> Result<Self, Self::Error> {
        // Run the spec's canonical schema through the normalize/validate pipeline.
        let value = OpenAiSchemaPipeline::apply(spec.canonical_parameters_schema_json()?)?;
        Ok(Self(value))
    }
}

/// Errors produced while adapting a tool schema for OpenAI.
#[derive(Debug, Error)]
pub(super) enum OpenAiToolSchemaError {
    #[error("{0}")]
    InvalidParametersSchema(String),
    #[error("OpenAI strict tool schemas do not support `{keyword}` at {path}")]
    UnsupportedKeyword { path: String, keyword: &'static str },
    #[error("OpenAI strict tool schemas do not support array-valued `type` at {path}")]
    UnsupportedTypeUnion { path: String },
}

impl From<ToolSpecError> for OpenAiToolSchemaError {
    fn from(value: ToolSpecError) -> Self {
        Self::InvalidParametersSchema(value.to_string())
    }
}

/// Ordered normalization steps followed by a final validation pass.
struct OpenAiSchemaPipeline;

impl OpenAiSchemaPipeline {
    fn apply(mut schema: Value) -> Result<Value, OpenAiToolSchemaError> {
        // Normalizers mutate the schema in place, in order.
        for normalizer in [
            strip_schema_metadata as SchemaNormalizer,
            strip_rust_numeric_formats,
            complete_required_arrays,
        ] {
            normalizer(&mut schema)?;
        }
        {
            // Validation runs last, over the fully normalized schema.
            let validator = validate_openai_compatibility as SchemaValidator;
            validator(&schema)?;
        }
        Ok(schema)
    }
}

/// Removes `$schema` markers from every node of the schema tree.
fn strip_schema_metadata(schema: &mut Value) -> Result<(), OpenAiToolSchemaError> {
    walk_schema_mut(schema, &SchemaPath::root(), &mut |node, _| {
        node.remove("$schema");
        Ok(())
    })
}

/// Drops Rust-specific numeric `format` hints (e.g. `uint32`) from every node.
fn strip_rust_numeric_formats(schema: &mut Value) -> Result<(), OpenAiToolSchemaError> {
    walk_schema_mut(schema, &SchemaPath::root(), &mut |node, _| {
        let should_strip = node
            .get("format")
            .and_then(Value::as_str)
            .is_some_and(is_rust_numeric_format);
        if should_strip {
            node.remove("format");
        }
        Ok(())
    })
}

/// Rebuilds `required` from the `properties` keys at every object level, so
/// every declared property is marked required.
fn complete_required_arrays(schema: &mut Value) -> Result<(), OpenAiToolSchemaError> {
    walk_schema_mut(schema, &SchemaPath::root(), &mut |node, _| {
        let Some(properties) = node.get("properties").and_then(Value::as_object) else {
            return Ok(());
        };
        node.insert(
            "required".to_string(),
            Value::Array(properties.keys().cloned().map(Value::String).collect()),
        );
        Ok(())
    })
}

/// Rejects constructs the strict tool schema dialect cannot express.
fn validate_openai_compatibility(schema: &Value) -> Result<(), OpenAiToolSchemaError> {
    walk_schema(schema, &SchemaPath::root(), &mut |node, path| {
        if node.contains_key("oneOf") {
            return Err(OpenAiToolSchemaError::UnsupportedKeyword {
                path: path.to_string(),
                keyword: "oneOf",
            });
        }
        if matches!(node.get("type"), Some(Value::Array(_))) {
            return
Err(OpenAiToolSchemaError::UnsupportedTypeUnion {
                path: path.to_string(),
            });
        }
        Ok(())
    })
}

/// Returns true for schemars' Rust integer `format` strings.
fn is_rust_numeric_format(format: &str) -> bool {
    matches!(
        format,
        "int8"
            | "int16"
            | "int32"
            | "int64"
            | "int128"
            | "isize"
            | "uint"
            | "uint8"
            | "uint16"
            | "uint32"
            | "uint64"
            | "uint128"
            | "usize"
    )
}

/// Depth-first mutable walk over every object node of a JSON schema.
fn walk_schema_mut(
    value: &mut Value,
    path: &SchemaPath,
    visitor: &mut impl FnMut(&mut Map<String, Value>, &SchemaPath) -> Result<(), OpenAiToolSchemaError>,
) -> Result<(), OpenAiToolSchemaError> {
    // Non-object nodes carry no subschemas to visit.
    let Value::Object(node) = value else {
        return Ok(());
    };
    visitor(node, path)?;
    walk_schema_children_mut(node, path, visitor)
}

/// Recurses into every keyword position that may hold subschemas (mutable).
fn walk_schema_children_mut(
    node: &mut Map<String, Value>,
    path: &SchemaPath,
    visitor: &mut impl FnMut(&mut Map<String, Value>, &SchemaPath) -> Result<(), OpenAiToolSchemaError>,
) -> Result<(), OpenAiToolSchemaError> {
    // Keywords whose value is a single subschema.
    for key in ["items", "contains", "if", "then", "else", "not"] {
        if let Some(child) = node.get_mut(key) {
            walk_schema_mut(child, &path.with_key(key), visitor)?;
        }
    }
    // Keywords whose value is an array of subschemas.
    for key in ["anyOf", "oneOf", "allOf", "prefixItems"] {
        let Some(entries) = node.get_mut(key).and_then(Value::as_array_mut) else {
            continue;
        };
        for (index, child) in entries.iter_mut().enumerate() {
            walk_schema_mut(child, &path.with_index(key, index), visitor)?;
        }
    }
    // Keywords whose value is a map of named subschemas.
    for key in ["properties", "$defs", "definitions", "dependentSchemas"] {
        let Some(entries) = node.get_mut(key).and_then(Value::as_object_mut) else {
            continue;
        };
        for (entry_key, child) in entries.iter_mut() {
            walk_schema_mut(child, &path.with_key(key).with_key(entry_key), visitor)?;
        }
    }
    Ok(())
}

/// Depth-first read-only walk over every object node of a JSON schema.
fn walk_schema(
    value: &Value,
    path: &SchemaPath,
    visitor: &mut impl FnMut(&Map<String, Value>, &SchemaPath) -> Result<(), OpenAiToolSchemaError>,
) -> Result<(), OpenAiToolSchemaError> {
    let Value::Object(node) = value else {
        return Ok(());
    };
    visitor(node, path)?;
    walk_schema_children(node, path, visitor)
}

/// Read-only counterpart of `walk_schema_children_mut`; visits the same keywords.
fn walk_schema_children(
    node: &Map<String, Value>,
    path: &SchemaPath,
    visitor: &mut impl FnMut(&Map<String, Value>,
&SchemaPath) -> Result<(), OpenAiToolSchemaError>,
) -> Result<(), OpenAiToolSchemaError> {
    for key in ["items", "contains", "if", "then", "else", "not"] {
        if let Some(child) = node.get(key) {
            walk_schema(child, &path.with_key(key), visitor)?;
        }
    }
    for key in ["anyOf", "oneOf", "allOf", "prefixItems"] {
        let Some(entries) = node.get(key).and_then(Value::as_array) else {
            continue;
        };
        for (index, child) in entries.iter().enumerate() {
            walk_schema(child, &path.with_index(key, index), visitor)?;
        }
    }
    for key in ["properties", "$defs", "definitions", "dependentSchemas"] {
        let Some(entries) = node.get(key).and_then(Value::as_object) else {
            continue;
        };
        for (entry_key, child) in entries {
            walk_schema(child, &path.with_key(key).with_key(entry_key), visitor)?;
        }
    }
    Ok(())
}

/// Dotted JSON path used in error messages (e.g. `$.properties.foo`).
#[derive(Clone, Debug)]
struct SchemaPath(Vec<String>);

impl SchemaPath {
    fn root() -> Self {
        Self(vec!["$".to_string()])
    }

    fn with_key(&self, key: impl Into<String>) -> Self {
        let mut path = self.0.clone();
        path.push(key.into());
        Self(path)
    }

    fn with_index(&self, key: impl Into<String>, index: usize) -> Self {
        let mut path = self.0.clone();
        path.push(key.into());
        path.push(index.to_string());
        Self(path)
    }
}

impl std::fmt::Display for SchemaPath {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0.join("."))
    }
}

#[cfg(test)]
mod tests {
    use schemars::JsonSchema;
    use serde_json::json;
    use swiftide_core::chat_completion::ToolSpec;

    use super::OpenAiToolSchema;

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    #[serde(deny_unknown_fields)]
    struct NestedCommentArgs {
        request: NestedCommentRequest,
    }

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    #[serde(deny_unknown_fields)]
    struct NestedCommentRequest {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        body: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        text: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        page_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
        block_id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        discussion_id: Option<String>,
    }

    #[test]
    fn openai_tool_schema_strips_schema_metadata_and_rust_formats() {
        let spec = ToolSpec::builder()
            .name("comment")
            .description("Create a comment")
            .parameters_schema(
                serde_json::from_value::<schemars::Schema>(json!({
                    "$schema": "https://json-schema.org/draft/2020-12/schema",
                    "type": "object",
                    "properties": {
                        "page_size": {
                            "type": ["integer", "null"],
                            "format": "uint",
                            "minimum": 0
                        }
                    }
                }))
                .unwrap(),
            )
            .build()
            .unwrap();

        let schema = OpenAiToolSchema::try_from(&spec).unwrap().into_value();

        // `$schema` is removed, the Rust `uint` format is gone, and the nullable
        // type union appears as `anyOf` in the final schema.
        assert!(schema.get("$schema").is_none());
        assert_eq!(
            schema["properties"]["page_size"]["anyOf"],
            json!([
                {
                    "type": "integer",
                    "minimum": 0
                },
                {
                    "type": "null"
                }
            ])
        );
    }

    #[test]
    fn openai_tool_schema_adds_recursive_required_arrays() {
        let spec = ToolSpec::builder()
            .name("comment")
            .description("Create a comment")
            .parameters_schema(schemars::schema_for!(NestedCommentArgs))
            .build()
            .unwrap();

        let schema = OpenAiToolSchema::try_from(&spec).unwrap().into_value();

        // The nested struct lives behind a `$ref`; resolve its definition name first.
        let nested_ref = schema["properties"]["request"]["$ref"]
            .as_str()
            .expect("nested request should be referenced");
        let nested_name = nested_ref
            .rsplit('/')
            .next()
            .expect("nested request ref name");

        // All fields — even the `Option`al ones — must be listed as required.
        assert_eq!(
            schema["$defs"][nested_name]["required"],
            json!(["block_id", "body", "discussion_id", "page_id", "text"])
        );
    }

    #[test]
    fn openai_tool_schema_rejects_non_nullable_one_of() {
        let spec = ToolSpec::builder()
            .name("comment")
            .description("Create a comment")
            .parameters_schema(
                serde_json::from_value::<schemars::Schema>(json!({
                    "type": "object",
                    "properties": {
                        "content": {
                            "oneOf": [
                                { "type": "string" },
                                { "type": "integer" }
                            ]
                        }
                    }
                }))
                .unwrap(),
            )
            .build()
            .unwrap();

        // Validation must reject `oneOf`, and the error should name the keyword.
        let error = OpenAiToolSchema::try_from(&spec).expect_err("oneOf should be rejected");

        assert!(error.to_string().contains("`oneOf`"));
    }
}
================================================ FILE: swiftide-integrations/src/parquet/loader.rs ================================================ use anyhow::{Context as _, Result}; use arrow_array::{LargeStringArray, StringArray, StringViewArray}; use fs_err::tokio::File; use futures_util::StreamExt as _; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; use swiftide_core::{ Loader, indexing::{IndexingStream, TextNode}, }; use tokio::runtime::Handle; use super::Parquet; impl Loader for Parquet { type Output = String; fn into_stream(self) -> IndexingStream<String> { let mut builder = tokio::task::block_in_place(|| { Handle::current().block_on(async { let file = File::open(self.path).await.expect("Failed to open file"); ParquetRecordBatchStreamBuilder::new(file) .await .context("Failed to load builder") .unwrap() .with_batch_size(self.batch_size) }) }); let file_metadata = builder.metadata().file_metadata().clone(); dbg!(file_metadata.schema_descr().columns()); let column_idx = file_metadata .schema() .get_fields() .iter() .enumerate() .find_map(|(pos, column)| { if self.column_name == column.name() { Some(pos) } else { None } }) .unwrap_or_else(|| panic!("Column {} not found in dataset", &self.column_name)); let mask = ProjectionMask::roots(file_metadata.schema_descr(), [column_idx]); builder = builder.with_projection(mask); let stream = builder.build().expect("Failed to build parquet builder"); let swiftide_stream = stream.flat_map_unordered(None, move |result_batch| { let Ok(batch) = result_batch else { let new_result: Result<TextNode> = Err(anyhow::anyhow!(result_batch.unwrap_err())); return vec![new_result].into(); }; assert!(batch.num_columns() == 1, "Number of columns _must_ be 1"); let column = batch.column(0); // Should only have one column at this point let node_values = if let Some(values) = column.as_any().downcast_ref::<StringArray>() { values .iter() .flatten() .map(TextNode::from) .map(Ok) .collect::<Vec<_>>() } else if let 
Some(values) = column.as_any().downcast_ref::<LargeStringArray>() { values .iter() .flatten() .map(TextNode::from) .map(Ok) .collect::<Vec<_>>() } else if let Some(values) = column.as_any().downcast_ref::<StringViewArray>() { values .iter() .flatten() .map(TextNode::from) .map(Ok) .collect::<Vec<_>>() } else { let new_result: Result<TextNode> = Err(anyhow::anyhow!( "Parquet column is not a string array (got {:?})", column.data_type() )); return vec![new_result].into(); }; IndexingStream::iter(node_values) }); swiftide_stream.boxed().into() // let mask = ProjectionMask:: } fn into_stream_boxed(self: Box<Self>) -> IndexingStream<String> { self.into_stream() } } #[cfg(test)] mod tests { use std::path::PathBuf; use futures_util::TryStreamExt as _; use super::*; #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn test_parquet_loader() { let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); path.push("src/parquet/test.parquet"); dbg!(&path); let loader = Parquet::builder() .path(path) .column_name("chunk") .build() .unwrap(); let result = loader.into_stream().try_collect::<Vec<_>>().await.unwrap(); let expected = [TextNode::new("hello"), TextNode::new("world")]; assert_eq!(result, expected); } } ================================================ FILE: swiftide-integrations/src/parquet/mod.rs ================================================ //! Stream data from parquet files use std::path::PathBuf; use derive_builder::Builder; pub mod loader; /// Stream data from parquet files on a single column /// /// Provide a path, column and optional batch size. The column must be of type `StringArray`. Then /// the column is loaded into the chunks of the Node. /// /// # Panics /// /// The loader can panic during initialization if anything with parquet or arrow fails before /// starting the stream. 
#[derive(Debug, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Parquet {
    // Path to the parquet file on disk.
    path: PathBuf,
    // Name of the (string-typed) column to stream.
    column_name: String,
    // Number of rows per record batch read from the file.
    #[builder(default = "1024")]
    batch_size: usize,
}

impl Parquet {
    // Convenience constructor for the derived builder.
    pub fn builder() -> ParquetBuilder {
        ParquetBuilder::default()
    }
}
================================================ FILE: swiftide-integrations/src/pgvector/fixtures.rs ================================================
//! Test fixtures and utilities for pgvector integration testing.
//!
//! Provides test infrastructure and helper types to verify vector storage and retrieval:
//! - Mock data generation for different embedding modes
//! - Test containers for `PostgreSQL` with pgvector extension
//! - Common test scenarios and assertions
//!
//! # Examples
//!
//! ```rust
//! use swiftide_integrations::pgvector::fixtures::{TestContext, PgVectorTestData};
//! use swiftide_core::indexing::{EmbedMode, EmbeddedField};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Initialize test context with PostgreSQL container
//! let context = TestContext::setup_with_cfg(
//!     Some(vec!["category", "priority"]),
//!     vec![EmbeddedField::Combined].into_iter().collect()
//! ).await?;
//!
//! // Create test data for different embedding modes
//! let test_data = PgVectorTestData {
//!     embed_mode: EmbedMode::SingleWithMetadata,
//!     chunk: "test content",
//!     metadata: None,
//!     vectors: vec![PgVectorTestData::create_test_vector(
//!         EmbeddedField::Combined,
//!         1.0
//!     )],
//!     expected_in_results: true,
//! };
//! # Ok(())
//! # }
//! ```
//!
//! The module supports testing for:
//! - Single embedding with/without metadata
//! - Per-field embeddings
//! - Combined embedding modes
//! - Different vector configurations
//! - Various metadata scenarios
use crate::pgvector::PgVector;
use std::collections::HashSet;
use swiftide_core::{
    Persist,
    indexing::{self, EmbeddedField},
};
use testcontainers::{ContainerAsync, GenericImage};

/// Test data structure for pgvector integration testing.
/// /// Provides a flexible structure to test different embedding modes and configurations, /// including metadata handling and vector generation. /// /// # Examples /// /// ```rust /// use swiftide_integrations::pgvector::fixtures::PgVectorTestData; /// use swiftide_core::indexing::{EmbedMode, EmbeddedField}; /// /// let test_data = PgVectorTestData { /// embed_mode: EmbedMode::SingleWithMetadata, /// chunk: "test content", /// metadata: None, /// vectors: vec![PgVectorTestData::create_test_vector( /// EmbeddedField::Combined, /// 1.0 /// )], /// }; /// ``` #[derive(Clone)] pub(crate) struct PgVectorTestData<'a> { /// Embedding mode for the test case pub embed_mode: indexing::EmbedMode, /// Test content chunk pub chunk: &'a str, /// Optional metadata for testing metadata handling pub metadata: Option<indexing::Metadata>, /// Vector embeddings with their corresponding fields pub vectors: Vec<(indexing::EmbeddedField, Vec<f32>)>, pub expected_in_results: bool, } impl PgVectorTestData<'_> { pub(crate) fn to_node(&self) -> indexing::TextNode { // Create the initial builder let mut base_builder = indexing::TextNode::builder(); // Set the required fields let mut builder = base_builder.chunk(self.chunk).embed_mode(self.embed_mode); // Add metadata if it exists if let Some(metadata) = &self.metadata { builder = builder.metadata(metadata.clone()); } // Build the node and add vectors let mut node = builder.build().unwrap(); node.vectors = Some(self.vectors.clone().into_iter().collect()); node } pub(crate) fn create_test_vector( field: EmbeddedField, base_value: f32, ) -> (EmbeddedField, Vec<f32>) { (field, vec![base_value; 384]) } } /// Test context managing `PostgreSQL` container and pgvector storage. /// /// Handles the lifecycle of test containers and provides configured storage /// instances for testing. 
/// /// # Examples /// /// ```rust /// # use swiftide_integrations::pgvector::fixtures::TestContext; /// # use swiftide_core::indexing::EmbeddedField; /// # async fn example() -> Result<(), Box<dyn std::error::Error>> { /// // Setup test context with specific configuration /// let context = TestContext::setup_with_cfg( /// Some(vec!["category"]), /// vec![EmbeddedField::Combined].into_iter().collect() /// ).await?; /// /// // Use context for testing /// context.pgv_storage.setup().await?; /// # Ok(()) /// # } /// ``` pub(crate) struct TestContext { /// Configured pgvector storage instance pub(crate) pgv_storage: PgVector, /// Container instance running `PostgreSQL` with pgvector _pgv_db_container: ContainerAsync<GenericImage>, } impl TestContext { /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage /// with configurable metadata fields pub(crate) async fn setup_with_cfg( metadata_fields: Option<Vec<&str>>, vector_fields: HashSet<EmbeddedField>, ) -> Result<Self, Box<dyn std::error::Error>> { // Start `PostgreSQL` container and obtain the connection URL let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; tracing::info!("Postgres database URL: {:#?}", pgv_db_url); // Initialize the connection pool outside of the builder chain let mut connection_pool = PgVector::builder(); // Configure PgVector storage let mut builder = connection_pool .db_url(pgv_db_url) .vector_size(384) .table_name("swiftide_pgvector_test".to_string()); // Add all vector fields for vector_field in vector_fields { builder = builder.with_vector(vector_field); } // Add all metadata fields if let Some(metadata_fields_inner) = metadata_fields { for field in metadata_fields_inner { builder = builder.with_metadata(field); } } let pgv_storage = builder.build().map_err(|err| { tracing::error!("Failed to build PgVector: {}", err); err })?; // Set up PgVector storage (create the table if not exists) pgv_storage.setup().await.map_err(|err| { 
tracing::error!("PgVector setup failed: {}", err); err })?; Ok(Self { pgv_storage, _pgv_db_container: pgv_db_container, }) } } ================================================ FILE: swiftide-integrations/src/pgvector/mod.rs ================================================ //! Integration module for `PostgreSQL` vector database (pgvector) operations. //! //! This module provides a client interface for vector similarity search operations using pgvector, //! supporting: //! - Vector collection management with configurable schemas //! - Efficient vector storage and indexing //! - Connection pooling with automatic retries //! - Batch operations for optimized performance //! - Metadata included in retrieval //! //! The functionality is primarily used through the [`PgVector`] client, which implements //! the [`Persist`] trait for seamless integration with indexing and query pipelines. //! //! # Example //! ```rust //! # use swiftide_integrations::pgvector::PgVector; //! # async fn example() -> anyhow::Result<()> { //! let client = PgVector::builder() //! .db_url("postgresql://localhost:5432/vectors") //! .vector_size(384) //! .build()?; //! //! # Ok(()) //! # } //! ``` #[cfg(test)] mod fixtures; mod persist; mod pgv_table_types; mod retrieve; use anyhow::Result; use derive_builder::Builder; use sqlx::PgPool; use std::fmt; use std::sync::Arc; use std::sync::OnceLock; use tokio::time::Duration; pub use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig}; /// Default maximum connections for the database connection pool. const DB_POOL_CONN_MAX: u32 = 10; /// Default maximum retries for database connection attempts. const DB_POOL_CONN_RETRY_MAX: u32 = 3; /// Delay between connection retry attempts, in seconds. const DB_POOL_CONN_RETRY_DELAY_SECS: u64 = 3; /// Default batch size for storing nodes. const BATCH_SIZE: usize = 50; /// Represents a Pgvector client with configuration options. 
/// /// This struct is used to interact with the Pgvector vector database, providing methods to manage /// vector collections, store data, and ensure efficient searches. The client can be cloned with low /// cost as it shares connections. #[derive(Builder, Clone)] #[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))] pub struct PgVector { /// Name of the table to store vectors. #[builder(default = "String::from(\"swiftide_pgv_store\")")] table_name: String, /// Default vector size; can be customized per configuration. vector_size: i32, /// Batch size for storing nodes. #[builder(default = "BATCH_SIZE")] batch_size: usize, /// Field configurations for the `PgVector` table schema. /// /// Supports multiple field types (see [`FieldConfig`]). #[builder(default)] fields: Vec<FieldConfig>, /// Database connection URL. db_url: String, /// Maximum connections allowed in the connection pool. #[builder(default = "DB_POOL_CONN_MAX")] db_max_connections: u32, /// Maximum retry attempts for establishing a database connection. #[builder(default = "DB_POOL_CONN_RETRY_MAX")] db_max_retry: u32, /// Delay between retry attempts for database connections. #[builder(default = "Duration::from_secs(DB_POOL_CONN_RETRY_DELAY_SECS)")] db_conn_retry_delay: Duration, /// Lazy-initialized database connection pool. #[builder(default = "Arc::new(OnceLock::new())")] connection_pool: Arc<OnceLock<PgPool>>, /// SQL statement used for executing bulk insert. #[builder(default = "Arc::new(OnceLock::new())")] sql_stmt_bulk_insert: Arc<OnceLock<String>>, } impl fmt::Debug for PgVector { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PgVector") .field("table_name", &self.table_name) .field("vector_size", &self.vector_size) .field("batch_size", &self.batch_size) .finish() } } impl PgVector { /// Creates a new instance of `PgVectorBuilder` with default settings. /// /// # Returns /// /// A new `PgVectorBuilder`. 
pub fn builder() -> PgVectorBuilder { PgVectorBuilder::default() } /// Retrieves a connection pool for `PostgreSQL`. /// /// This function returns the connection pool used for interacting with the `PostgreSQL` /// database. It fetches the pool from the `PgDBConnectionPool` struct. /// /// # Returns /// /// A `Result` that, on success, contains the `PgPool` representing the database connection /// pool. On failure, an error is returned. /// /// # Errors /// /// This function will return an error if it fails to retrieve the connection pool, which could /// occur if the underlying connection to `PostgreSQL` has not been properly established. pub async fn get_pool(&self) -> Result<&PgPool> { self.pool_get_or_initialize().await } pub fn get_table_name(&self) -> &str { &self.table_name } } impl PgVectorBuilder { /// Adds a vector configuration to the builder. /// /// # Arguments /// /// * `config` - The vector configuration to add, which can be converted into a `VectorConfig`. /// /// # Returns /// /// A mutable reference to the builder with the new vector configuration added. pub fn with_vector(&mut self, config: impl Into<VectorConfig>) -> &mut Self { // Use `get_or_insert_with` to initialize `fields` if it's `None` self.fields .get_or_insert_with(Self::default_fields) .push(FieldConfig::Vector(config.into())); self } /// Sets the metadata configuration for the vector similarity search. /// /// This method allows you to specify metadata configurations for vector similarity search using /// `MetadataConfig`. The provided configuration will be added as a new field in the /// builder. /// /// # Arguments /// /// * `config` - The metadata configuration to use. /// /// # Returns /// /// * Returns a mutable reference to `self` for method chaining. 
pub fn with_metadata(&mut self, config: impl Into<MetadataConfig>) -> &mut Self { // Use `get_or_insert_with` to initialize `fields` if it's `None` self.fields .get_or_insert_with(Self::default_fields) .push(FieldConfig::Metadata(config.into())); self } pub fn default_fields() -> Vec<FieldConfig> { vec![FieldConfig::ID, FieldConfig::Chunk] } } #[cfg(test)] mod tests { use crate::pgvector::fixtures::{PgVectorTestData, TestContext}; use futures_util::TryStreamExt; use std::collections::HashSet; use swiftide_core::{ Persist, Retrieve, document::Document, indexing::{self, EmbedMode, EmbeddedField}, querying::{Query, search_strategies::SimilaritySingleEmbedding, states}, }; use test_case::test_case; #[test_log::test(tokio::test)] async fn test_metadata_filter_with_vector_search() { let test_context = TestContext::setup_with_cfg( vec!["category", "priority"].into(), HashSet::from([EmbeddedField::Combined]), ) .await .expect("Test setup failed"); // Create nodes with different metadata and vectors let nodes = vec![ indexing::TextNode::new("content1") .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]) .with_metadata(vec![("category", "A"), ("priority", "1")]), indexing::TextNode::new("content2") .with_vectors([(EmbeddedField::Combined, vec![1.1; 384])]) .with_metadata(vec![("category", "A"), ("priority", "2")]), indexing::TextNode::new("content3") .with_vectors([(EmbeddedField::Combined, vec![1.2; 384])]) .with_metadata(vec![("category", "B"), ("priority", "1")]), ] .into_iter() .map(|node| node.to_owned()) .collect(); // Store all nodes test_context .pgv_storage .batch_store(nodes) .await .try_collect::<Vec<_>>() .await .unwrap(); // Test combined metadata and vector search let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vec![1.0; 384]); let search_strategy = SimilaritySingleEmbedding::from_filter("category = \"A\"".to_string()); let result = test_context .pgv_storage .retrieve(&search_strategy, query.clone()) .await .unwrap(); 
assert_eq!(result.documents().len(), 2); let contents = result .documents() .iter() .map(Document::content) .collect::<Vec<_>>(); assert!(contents.contains(&"content1")); assert!(contents.contains(&"content2")); // Additional test with priority filter let search_strategy = SimilaritySingleEmbedding::from_filter("priority = \"1\"".to_string()); let result = test_context .pgv_storage .retrieve(&search_strategy, query) .await .unwrap(); assert_eq!(result.documents().len(), 2); let contents = result .documents() .iter() .map(Document::content) .collect::<Vec<_>>(); assert!(contents.contains(&"content1")); assert!(contents.contains(&"content3")); } #[test_log::test(tokio::test)] async fn test_vector_similarity_search_accuracy() { let test_context = TestContext::setup_with_cfg( vec!["category", "priority"].into(), HashSet::from([EmbeddedField::Combined]), ) .await .expect("Test setup failed"); // Create nodes with known vector relationships let base_vector = vec![1.0; 384]; let similar_vector = base_vector.iter().map(|x| x + 0.1).collect::<Vec<_>>(); let dissimilar_vector = vec![-1.0; 384]; let nodes = vec![ indexing::TextNode::new("base_content") .with_vectors([(EmbeddedField::Combined, base_vector)]) .with_metadata(vec![("category", "A"), ("priority", "1")]), indexing::TextNode::new("similar_content") .with_vectors([(EmbeddedField::Combined, similar_vector)]) .with_metadata(vec![("category", "A"), ("priority", "2")]), indexing::TextNode::new("dissimilar_content") .with_vectors([(EmbeddedField::Combined, dissimilar_vector)]) .with_metadata(vec![("category", "B"), ("priority", "1")]), ] .into_iter() .map(|node| node.to_owned()) .collect(); // Store all nodes test_context .pgv_storage .batch_store(nodes) .await .try_collect::<Vec<_>>() .await .unwrap(); // Search with base vector let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vec![1.0; 384]); let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); 
search_strategy.with_top_k(2); let result = test_context .pgv_storage .retrieve(&search_strategy, query) .await .unwrap(); // Verify that similar vectors are retrieved first assert_eq!(result.documents().len(), 2); let contents = result .documents() .iter() .map(Document::content) .collect::<Vec<_>>(); assert!(contents.contains(&"base_content")); assert!(contents.contains(&"similar_content")); } #[test_case( // SingleWithMetadata - No Metadata vec![ PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "single_no_meta_1", metadata: None, vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)], expected_in_results: true, }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "single_no_meta_2", metadata: None, vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)], expected_in_results: true, } ], HashSet::from([EmbeddedField::Combined]) ; "SingleWithMetadata mode without metadata")] #[test_case( // SingleWithMetadata - With Metadata vec![ PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "single_with_meta_1", metadata: Some(vec![ ("category", "A"), ("priority", "high") ].into()), vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)], expected_in_results: true, }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "single_with_meta_2", metadata: Some(vec![ ("category", "B"), ("priority", "low") ].into()), vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)], expected_in_results: true, } ], HashSet::from([EmbeddedField::Combined]) ; "SingleWithMetadata mode with metadata")] #[test_log::test(tokio::test)] async fn test_persist_nodes( test_cases: Vec<PgVectorTestData<'_>>, vector_fields: HashSet<EmbeddedField>, ) { // Extract all possible metadata fields from test cases let metadata_fields: Vec<&str> = test_cases .iter() .filter_map(|case| case.metadata.as_ref()) .flat_map(|metadata| 
metadata.iter().map(|(key, _)| key.as_str())) .collect::<std::collections::HashSet<_>>() .into_iter() .collect(); // Initialize test context with all required metadata fields let test_context = TestContext::setup_with_cfg(Some(metadata_fields), vector_fields) .await .expect("Test setup failed"); // Convert test cases to nodes and store them let nodes: Vec<indexing::TextNode> = test_cases.iter().map(PgVectorTestData::to_node).collect(); // Test batch storage let stored_nodes = test_context .pgv_storage .batch_store(nodes.clone()) .await .try_collect::<Vec<_>>() .await .expect("Failed to store nodes"); assert_eq!( stored_nodes.len(), nodes.len(), "All nodes should be stored" ); // Verify storage and retrieval for each test case for (test_case, stored_node) in test_cases.iter().zip(stored_nodes.iter()) { // 1. Verify basic node properties assert_eq!( stored_node.chunk, test_case.chunk, "Stored chunk should match" ); assert_eq!( stored_node.embed_mode, test_case.embed_mode, "Embed mode should match" ); // 2. Verify vectors were stored correctly let stored_vectors = stored_node .vectors .as_ref() .expect("Vectors should be present"); assert_eq!( stored_vectors.len(), test_case.vectors.len(), "Vector count should match" ); // 3. 
Test vector similarity search for (field, vector) in &test_case.vectors { let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vector.clone()); let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); search_strategy.with_top_k(nodes.len() as u64); let result = test_context .pgv_storage .retrieve(&search_strategy, query) .await .expect("Retrieval should succeed"); if test_case.expected_in_results { assert!( result .documents() .iter() .map(Document::content) .collect::<Vec<_>>() .contains(&test_case.chunk), "Document should be found in results for field {field}", ); } } } } } ================================================ FILE: swiftide-integrations/src/pgvector/persist.rs ================================================ //! Storage persistence implementation for vector embeddings. //! //! Implements the [`Persist`] trait for [`PgVector`], providing vector storage capabilities: //! - Database schema initialization and setup //! - Single-node storage operations //! - Optimized batch storage with configurable batch sizes //! //! NOTE: Persisting and retrieving metadata is not supported at the moment. //! //! The implementation ensures thread-safe concurrent access and handles //! connection management automatically. 
use crate::pgvector::PgVector; use anyhow::{Result, anyhow}; use async_trait::async_trait; use swiftide_core::{ Persist, indexing::{IndexingStream, TextNode}, }; #[async_trait] impl Persist for PgVector { type Input = String; type Output = String; #[tracing::instrument(skip_all)] async fn setup(&self) -> Result<()> { // Get or initialize the connection pool let pool = self.pool_get_or_initialize().await?; if self.sql_stmt_bulk_insert.get().is_none() { let sql = self.generate_unnest_upsert_sql()?; self.sql_stmt_bulk_insert .set(sql) .map_err(|_| anyhow!("SQL bulk store statement is already set"))?; } let mut tx = pool.begin().await?; // Create extension let sql = "CREATE EXTENSION IF NOT EXISTS vector"; sqlx::query(sql).execute(&mut *tx).await?; // Create table let create_table_sql = self.generate_create_table_sql()?; sqlx::query(&create_table_sql).execute(&mut *tx).await?; // Create HNSW index let index_sql = self.create_index_sql()?; sqlx::query(&index_sql).execute(&mut *tx).await?; tx.commit().await?; Ok(()) } #[tracing::instrument(skip_all)] async fn store(&self, node: TextNode) -> Result<TextNode> { let mut nodes = vec![node; 1]; self.store_nodes(&nodes).await?; let node = nodes.swap_remove(0); Ok(node) } #[tracing::instrument(skip_all)] async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> { self.store_nodes(&nodes).await.map(|()| nodes).into() } fn batch_size(&self) -> Option<usize> { Some(self.batch_size) } } #[cfg(test)] mod tests { use crate::pgvector::fixtures::TestContext; use std::collections::HashSet; use swiftide_core::{Persist, indexing::EmbeddedField}; #[test_log::test(tokio::test)] async fn test_persist_setup_no_error_when_table_exists() { let test_context = TestContext::setup_with_cfg( vec!["filter"].into(), HashSet::from([EmbeddedField::Combined]), ) .await .expect("Test setup failed"); test_context .pgv_storage .setup() .await .expect("PgVector setup should not fail when the table already exists"); } } 
================================================ FILE: swiftide-integrations/src/pgvector/pgv_table_types.rs ================================================ //! `PostgreSQL` table schema and type conversion utilities for vector storage. //! //! Provides schema configuration and data type conversion functionality: //! - Table schema generation with vector and metadata columns //! - Field configuration for different vector embedding types //! - HNSW index creation for similarity search optimization //! - Bulk data preparation and SQL query generation use crate::pgvector::PgVector; use anyhow::{Result, anyhow}; use pgvector as ExtPgVector; use regex::Regex; use sqlx::PgPool; use sqlx::postgres::PgArguments; use sqlx::postgres::PgPoolOptions; use std::collections::BTreeMap; use swiftide_core::indexing::{EmbeddedField, TextNode}; use tokio::time::sleep; /// Configuration for vector embedding columns in the `PostgreSQL` table. /// /// This struct defines how vector embeddings are stored and managed in the database, /// mapping Swiftide's embedded fields to `PostgreSQL` vector columns. #[derive(Clone, Debug)] pub struct VectorConfig { embedded_field: EmbeddedField, pub field: String, } impl VectorConfig { pub fn new(embedded_field: &EmbeddedField) -> Self { Self { embedded_field: embedded_field.clone(), field: format!( "vector_{}", PgVector::normalize_field_name(&embedded_field.to_string()), ), } } } impl From<EmbeddedField> for VectorConfig { fn from(val: EmbeddedField) -> Self { Self::new(&val) } } /// Configuration for metadata fields in the `PostgreSQL` table. /// /// Handles the mapping and storage of metadata fields, ensuring proper column naming /// and type conversion for `PostgreSQL` compatibility. 
#[derive(Clone, Debug)] pub struct MetadataConfig { field: String, original_field: String, } impl MetadataConfig { pub fn new<T: Into<String>>(original_field: T) -> Self { let original: String = original_field.into(); Self { field: format!("meta_{}", PgVector::normalize_field_name(&original)), original_field: original, } } } impl<T: AsRef<str>> From<T> for MetadataConfig { fn from(val: T) -> Self { Self::new(val.as_ref()) } } /// Field configuration types supported in the `PostgreSQL` table schema. /// /// Represents different field types that can be configured in the table schema, /// including vector embeddings, metadata, and system fields. #[derive(Clone, Debug)] pub enum FieldConfig { /// `Vector` - Vector embedding field configuration Vector(VectorConfig), /// `Metadata` - Metadata field configuration Metadata(MetadataConfig), /// `Chunk` - Text content storage field Chunk, /// `ID` - Primary key field ID, } impl FieldConfig { pub fn field_name(&self) -> &str { match self { FieldConfig::Vector(config) => &config.field, FieldConfig::Metadata(config) => &config.field, FieldConfig::Chunk => "chunk", FieldConfig::ID => "id", } } } /// Internal structure for managing bulk upsert operations. /// /// Collects and organizes data for efficient bulk insertions and updates, /// grouping related fields for UNNEST-based operations. 
struct BulkUpsertData<'a> { ids: Vec<sqlx::types::Uuid>, chunks: Vec<&'a str>, metadata_fields: Vec<Vec<serde_json::Value>>, vector_fields: Vec<Vec<ExtPgVector::Vector>>, field_mapping: FieldMapping<'a>, } struct FieldMapping<'a> { metadata_names: Vec<&'a str>, vector_names: Vec<&'a str>, } impl<'a> BulkUpsertData<'a> { fn new(fields: &'a [FieldConfig], size: usize) -> Self { let (metadata_names, vector_names): (Vec<&str>, Vec<&str>) = ( fields .iter() .filter_map(|field| match field { FieldConfig::Metadata(config) => Some(config.field.as_str()), _ => None, }) .collect(), fields .iter() .filter_map(|field| match field { FieldConfig::Vector(config) => Some(config.field.as_str()), _ => None, }) .collect(), ); Self { ids: Vec::with_capacity(size), chunks: Vec::with_capacity(size), metadata_fields: vec![Vec::with_capacity(size); metadata_names.len()], vector_fields: vec![Vec::with_capacity(size); vector_names.len()], field_mapping: FieldMapping { metadata_names, vector_names, }, } } fn get_metadata_index(&self, field: &str) -> Option<usize> { self.field_mapping .metadata_names .iter() .position(|&name| name == field) } fn get_vector_index(&self, field: &str) -> Option<usize> { self.field_mapping .vector_names .iter() .position(|&name| name == field) } } impl PgVector { /// Generates a SQL statement to create a table for storing vector embeddings. /// /// The table will include columns for an ID, chunk data, metadata, and a vector embedding. /// /// # Returns /// /// * The generated SQL statement. /// /// # Errors /// /// * Returns an error if the table name is invalid or if `vector_size` is not configured. 
pub fn generate_create_table_sql(&self) -> Result<String> {
    // Validate table_name and field_name (e.g., check against allowed patterns)
    if !Self::is_valid_identifier(&self.table_name) {
        return Err(anyhow::anyhow!("Invalid table name"));
    }

    // Map each configured field to its column DDL, then append the
    // primary-key clause as the final entry in the column list.
    let columns: Vec<String> = self
        .fields
        .iter()
        .map(|field| match field {
            FieldConfig::ID => "id UUID NOT NULL".to_string(),
            FieldConfig::Chunk => format!("{} TEXT NOT NULL", field.field_name()),
            FieldConfig::Metadata(_) => format!("{} JSONB", field.field_name()),
            FieldConfig::Vector(_) => {
                format!("{} VECTOR({})", field.field_name(), self.vector_size)
            }
        })
        .chain(std::iter::once("PRIMARY KEY (id)".to_string()))
        .collect();

    let sql = format!(
        "CREATE TABLE IF NOT EXISTS {} (\n {}\n)",
        self.table_name,
        columns.join(",\n ")
    );

    Ok(sql)
}

/// Generates the SQL statement to create an HNSW index on the vector column.
///
/// # Errors
///
/// Returns an error if:
/// - No vector field is found in the table configuration.
/// - The table name or field name is invalid.
pub fn create_index_sql(&self) -> Result<String> {
    let index_name = format!("{}_embedding_idx", self.table_name);

    // Indexes the first vector field; only one vector column is supported
    // for search (see `get_vector_column_name`).
    let vector_field = self
        .fields
        .iter()
        .find(|f| matches!(f, FieldConfig::Vector(_)))
        .ok_or_else(|| anyhow::anyhow!("No vector field found in configuration"))?
        .field_name();

    // Validate table_name and field_name (e.g., check against allowed patterns)
    if !Self::is_valid_identifier(&self.table_name)
        || !Self::is_valid_identifier(&index_name)
        || !Self::is_valid_identifier(vector_field)
    {
        return Err(anyhow::anyhow!("Invalid table or field name"));
    }

    Ok(format!(
        "CREATE INDEX IF NOT EXISTS {} ON {} USING hnsw ({} vector_cosine_ops)",
        index_name, &self.table_name, vector_field
    ))
}

/// Stores a list of nodes in the database using an upsert operation.
///
/// # Arguments
///
/// * `nodes` - A slice of `TextNode` objects to be stored.
///
/// # Returns
///
/// * `Result<()>` - `Ok` if all nodes are successfully stored, `Err` otherwise.
/// # Errors
///
/// This function will return an error if:
/// - The database connection pool is not established.
/// - Any of the SQL queries fail to execute due to schema mismatch, constraint violations, or
///   connectivity issues.
/// - Committing the transaction fails.
pub async fn store_nodes(&self, nodes: &[TextNode]) -> Result<()> {
    let pool = self.pool_get_or_initialize().await?;
    let mut tx = pool.begin().await?;
    // Flatten the nodes into column-major arrays for the UNNEST upsert.
    let bulk_data = self.prepare_bulk_data(nodes)?;

    // The statement is generated and cached during `setup()`.
    let sql = self
        .sql_stmt_bulk_insert
        .get()
        .ok_or_else(|| anyhow!("SQL bulk insert statement not set"))?;

    let query = self.bind_bulk_data_to_query(sqlx::query(sql), &bulk_data)?;

    query
        .execute(&mut *tx)
        .await
        .map_err(|e| anyhow!("Failed to store nodes: {e:?}"))?;

    tx.commit()
        .await
        .map_err(|e| anyhow!("Failed to commit transaction: {e:?}"))
}

/// Prepares data from nodes into vectors for bulk processing.
#[allow(clippy::implicit_clone)]
fn prepare_bulk_data<'a>(&'a self, nodes: &'a [TextNode]) -> Result<BulkUpsertData<'a>> {
    let mut bulk_data = BulkUpsertData::new(&self.fields, nodes.len());

    for node in nodes {
        bulk_data.ids.push(node.id());
        bulk_data.chunks.push(node.chunk.as_str());

        for field in &self.fields {
            match field {
                FieldConfig::Metadata(config) => {
                    let idx = bulk_data
                        .get_metadata_index(config.field.as_str())
                        .ok_or_else(|| anyhow!("Invalid metadata field"))?;

                    // A node missing a configured metadata key is an error,
                    // unlike missing vectors below which fall back to empty.
                    let value = node
                        .metadata
                        .get(&config.original_field)
                        .ok_or_else(|| anyhow!("Missing metadata field"))?;

                    // Wrap the value in a single-key map so the JSONB column
                    // is self-describing (keyed by the original field name).
                    let mut metadata_map = BTreeMap::new();
                    metadata_map.insert(config.original_field.clone(), value.clone());

                    bulk_data.metadata_fields[idx].push(serde_json::to_value(metadata_map)?);
                }
                FieldConfig::Vector(config) => {
                    let idx = bulk_data
                        .get_vector_index(config.field.as_str())
                        .ok_or_else(|| anyhow!("Invalid vector field"))?;

                    let data = node
                        .vectors
                        .as_ref()
                        .and_then(|v| v.get(&config.embedded_field))
                        .map(|v| v.to_vec())
                        .unwrap_or_default();

                    bulk_data.vector_fields[idx].push(ExtPgVector::Vector::from(data));
                }
                // ID and Chunk were already pushed once per node above.
                _ => (),
            }
        }
    }

    Ok(bulk_data)
}

/// Generates SQL for UNNEST-based bulk upsert.
///
/// # Returns
///
/// * `Result<String>` - The generated SQL statement or an error if fields are empty.
///
/// # Errors
///
/// Returns an error if `self.fields` is empty, as no valid SQL can be generated.
pub(crate) fn generate_unnest_upsert_sql(&self) -> Result<String> {
    if self.fields.is_empty() {
        return Err(anyhow!("Cannot generate upsert SQL with empty fields"));
    }

    let mut columns = Vec::new();
    let mut unnest_params = Vec::new();
    let mut param_counter = 1;

    for field in &self.fields {
        let name = field.field_name();
        columns.push(name.to_string());
        // One positional array parameter per column, cast to the matching
        // Postgres array type; bind order must follow `self.fields` order
        // (see `bind_bulk_data_to_query`).
        unnest_params.push(format!(
            "${param_counter}::{}",
            match field {
                FieldConfig::ID => "UUID[]",
                FieldConfig::Chunk => "TEXT[]",
                FieldConfig::Metadata(_) => "JSONB[]",
                FieldConfig::Vector(_) => "VECTOR[]",
            }
        ));
        param_counter += 1;
    }

    // On conflict, overwrite every column except the primary key.
    let update_columns = self
        .fields
        .iter()
        .filter(|field| !matches!(field, FieldConfig::ID)) // Skip ID field in updates
        .map(|field| {
            let name = field.field_name();
            format!("{name} = EXCLUDED.{name}")
        })
        .collect::<Vec<_>>()
        .join(", ");

    Ok(format!(
        r"
        INSERT INTO {} ({})
        SELECT {} FROM UNNEST({}) AS t({})
        ON CONFLICT (id) DO UPDATE SET {}",
        self.table_name,
        columns.join(", "),
        columns.join(", "),
        unnest_params.join(", "),
        columns.join(", "),
        update_columns
    ))
}

/// Binds bulk data to the SQL query, ensuring data arrays are matched to corresponding fields.
///
/// # Errors
///
/// Returns an error if any metadata or vector field is missing from the bulk data.
#[allow(clippy::implicit_clone)]
fn bind_bulk_data_to_query<'a>(
    &self,
    mut query: sqlx::query::Query<'a, sqlx::Postgres, PgArguments>,
    bulk_data: &'a BulkUpsertData,
) -> Result<sqlx::query::Query<'a, sqlx::Postgres, PgArguments>> {
    // NOTE: binds are positional; iterating `self.fields` here mirrors the
    // parameter order generated in `generate_unnest_upsert_sql`.
    for field in &self.fields {
        query = match field {
            FieldConfig::ID => query.bind(&bulk_data.ids),
            FieldConfig::Chunk => query.bind(&bulk_data.chunks),
            FieldConfig::Vector(config) => {
                let idx = bulk_data
                    .get_vector_index(config.field.as_str())
                    .ok_or_else(|| {
                        anyhow!("Vector field {} not found in bulk data", config.field)
                    })?;
                query.bind(&bulk_data.vector_fields[idx])
            }
            FieldConfig::Metadata(config) => {
                let idx = bulk_data
                    .get_metadata_index(config.field.as_str())
                    .ok_or_else(|| {
                        anyhow!("Metadata field {} not found in bulk data", config.field)
                    })?;
                query.bind(&bulk_data.metadata_fields[idx])
            }
        };
    }

    Ok(query)
}

/// Retrieves the name of the vector column configured in the schema.
///
/// # Returns
/// * `Ok(String)` - The name of the vector column if exactly one is configured.
///
/// # Errors
/// * `Error::NoEmbedding` - If no vector field is configured in the schema.
/// * `Error::MultipleEmbeddings` - If multiple vector fields are configured in the schema.
pub fn get_vector_column_name(&self) -> Result<String> {
    // Collect all vector-typed fields so the zero / one / many cases can be
    // distinguished below.
    let vector_fields: Vec<_> = self
        .fields
        .iter()
        .filter(|field| matches!(field, FieldConfig::Vector(_)))
        .collect();

    match vector_fields.as_slice() {
        [field] => Ok(field.field_name().to_string()),
        [] => Err(anyhow!("No vector field configured in schema")),
        _ => Err(anyhow!(
            "Search strategy for multiple vector fields in the schema is not yet implemented"
        )),
    }
}
}

impl PgVector {
    /// Normalizes an arbitrary field name into a safe SQL column suffix:
    /// truncates at the first bracket-like character, keeps at most three
    /// words, lowercases, joins with underscores, and strips everything that
    /// is not alphanumeric or an underscore.
    pub fn normalize_field_name(field: &str) -> String {
        // Define the special characters as an array
        let special_chars: [char; 4] = ['(', '[', '{', '<'];

        // First split by special characters and take the first part
        let base_text = field
            .split(|c| special_chars.contains(&c))
            .next()
            .unwrap_or(field)
            .trim();

        // Split by whitespace, take up to 3 words, convert to lowercase
        let normalized = base_text
            .split_whitespace()
            .take(3)
            .collect::<Vec<&str>>()
            .join("_")
            .to_lowercase();

        // Ensure the result only contains alphanumeric chars and underscores
        normalized
            .chars()
            .filter(|c| c.is_alphanumeric() || *c == '_')
            .collect()
    }

    /// Checks whether `identifier` is safe to interpolate into generated SQL
    /// as a table/index/column name.
    pub(crate) fn is_valid_identifier(identifier: &str) -> bool {
        // PostgreSQL identifier rules:
        // 1. Must start with a letter (a-z) or underscore
        // 2. Subsequent characters can be letters, underscores, digits (0-9), or dollar signs
        // 3. Maximum length is 63 bytes
        // 4. Cannot be a reserved keyword

        // Check length
        if identifier.is_empty() || identifier.len() > 63 {
            return false;
        }

        // Compile the pattern once and reuse it; this function runs for every
        // table/index/column validation, and recompiling the regex on each
        // call is needlessly expensive.
        static IDENTIFIER_REGEX: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
        let identifier_regex = IDENTIFIER_REGEX.get_or_init(|| {
            Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_$]*$").expect("identifier regex is valid")
        });
        if !identifier_regex.is_match(identifier) {
            return false;
        }

        // Check if it's not a reserved keyword
        !Self::is_reserved_keyword(identifier)
    }

    /// Case-insensitive check against a small deny-list of SQL keywords.
    pub(crate) fn is_reserved_keyword(word: &str) -> bool {
        // This list is not exhaustive. You may want to expand it based on
        // the PostgreSQL version you're using.
        const RESERVED_KEYWORDS: &[&str] = &[
            "SELECT", "FROM", "WHERE", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "TABLE",
            "INDEX", "ALTER", "ADD", "COLUMN", "AND", "OR", "NOT", "NULL", "TRUE", "FALSE",
            // Add more keywords as needed
        ];

        RESERVED_KEYWORDS.contains(&word.to_uppercase().as_str())
    }
}

impl PgVector {
    /// Establishes the connection pool, retrying up to `db_max_retry` times
    /// with `db_conn_retry_delay` between attempts.
    async fn create_pool(&self) -> Result<PgPool> {
        let pool_options = PgPoolOptions::new().max_connections(self.db_max_connections);

        for attempt in 1..=self.db_max_retry {
            match pool_options.clone().connect(self.db_url.as_ref()).await {
                Ok(pool) => {
                    tracing::info!("Successfully established database connection");
                    return Ok(pool);
                }
                Err(err) if attempt < self.db_max_retry => {
                    tracing::warn!(
                        error = %err,
                        attempt = attempt,
                        max_retries = self.db_max_retry,
                        "Database connection attempt failed, retrying..."
                    );
                    sleep(self.db_conn_retry_delay).await;
                }
                Err(err) => {
                    return Err(anyhow!(err).context("Failed to establish database connection"));
                }
            }
        }

        // Defensive fallback: the loop always returns on its final attempt,
        // so this is only reachable when `db_max_retry` is 0.
        Err(anyhow!(
            "Max connection retries ({}) exceeded",
            self.db_max_retry
        ))
    }

    /// Returns a reference to the `PgPool` if it is already initialized,
    /// or creates and initializes it if it is not.
    ///
    /// # Errors
    ///
    /// This function will return an error if pool creation fails.
pub async fn pool_get_or_initialize(&self) -> Result<&PgPool> { if let Some(pool) = self.connection_pool.get() { return Ok(pool); } let pool = self.create_pool().await?; self.connection_pool .set(pool) .map_err(|_| anyhow!("Pool already initialized"))?; // Re-check if the pool was set successfully, otherwise return an error self.connection_pool .get() .ok_or_else(|| anyhow!("Failed to retrieve connection pool after setting it")) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_valid_identifiers() { assert!(PgVector::is_valid_identifier("valid_name")); assert!(PgVector::is_valid_identifier("_valid_name")); assert!(PgVector::is_valid_identifier("valid_name_123")); assert!(PgVector::is_valid_identifier("validName")); } #[test] fn test_invalid_identifiers() { assert!(!PgVector::is_valid_identifier("")); // Empty string assert!(!PgVector::is_valid_identifier(&"a".repeat(64))); // Too long assert!(!PgVector::is_valid_identifier("123_invalid")); // Starts with a number assert!(!PgVector::is_valid_identifier("invalid-name")); // Contains hyphen assert!(!PgVector::is_valid_identifier("select")); // Reserved keyword } } ================================================ FILE: swiftide-integrations/src/pgvector/retrieve.rs ================================================ use crate::pgvector::{FieldConfig, PgVector, PgVectorBuilder}; use anyhow::{Result, anyhow}; use async_trait::async_trait; use pgvector::Vector; use sqlx::{Column, Row, prelude::FromRow, types::Uuid}; use std::fmt::Write as _; use swiftide_core::{ Retrieve, document::Document, indexing::Metadata, querying::{ Query, search_strategies::{CustomStrategy, SimilaritySingleEmbedding}, states, }, }; #[allow(dead_code)] #[derive(Debug, Clone)] struct VectorSearchResult { id: Uuid, chunk: String, metadata: Metadata, } impl From<VectorSearchResult> for Document { fn from(val: VectorSearchResult) -> Self { Document::new(val.chunk, Some(val.metadata)) } } impl FromRow<'_, sqlx::postgres::PgRow> for 
VectorSearchResult { fn from_row(row: &sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> { let mut metadata = Metadata::default(); // Metadata fields are stored each as prefixed meta_ fields. Perhaps we should add a single // metadata field instead of multiple fields. for column in row.columns() { if column.name().starts_with("meta_") { row.try_get::<serde_json::Value, _>(column.name())? .as_object() .and_then(|object| { object.keys().collect::<Vec<_>>().first().map(|key| { metadata.insert( key.to_owned(), object.get(key.as_str()).expect("infallible").clone(), ); }) }); } } Ok(VectorSearchResult { id: row.try_get("id")?, chunk: row.try_get("chunk")?, metadata, }) } } #[allow(clippy::redundant_closure_for_method_calls)] #[async_trait] impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector { #[tracing::instrument] async fn retrieve( &self, search_strategy: &SimilaritySingleEmbedding<String>, query_state: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { let embedding = if let Some(embedding) = query_state.embedding.as_ref() { Vector::from(embedding.clone()) } else { return Err(anyhow::Error::msg("Missing embedding in query state")); }; let vector_column_name = self.get_vector_column_name()?; let pool = self.pool_get_or_initialize().await?; let default_columns: Vec<_> = PgVectorBuilder::default_fields() .iter() .map(|f| f.field_name().to_string()) .chain( self.fields .iter() .filter(|f| matches!(f, FieldConfig::Metadata(_))) .map(|f| f.field_name().to_string()), ) .collect(); // Start building the SQL query let mut sql = format!( "SELECT {} FROM {}", default_columns.join(", "), self.table_name ); if let Some(filter) = search_strategy.filter() { let filter_parts: Vec<&str> = filter.split('=').collect(); if filter_parts.len() == 2 { let key = filter_parts[0].trim(); let value = filter_parts[1].trim().trim_matches('"'); tracing::debug!( "Filter being applied: key = {:#?}, value = {:#?}", key, value ); let sql_filter = format!( " WHERE 
meta_{}->>'{}' = '{}'", PgVector::normalize_field_name(key), key, value ); sql.push_str(&sql_filter); } else { return Err(anyhow!("Invalid filter format")); } } // Add the ORDER BY clause for vector similarity search write!(sql, " ORDER BY {vector_column_name} <=> $1 LIMIT $2")?; tracing::debug!("Running retrieve with SQL: {}", sql); let top_k = i32::try_from(search_strategy.top_k()) .map_err(|_| anyhow!("Failed to convert top_k to i32"))?; let data: Vec<VectorSearchResult> = sqlx::query_as(&sql) .bind(embedding) .bind(top_k) .fetch_all(pool) .await?; let docs = data.into_iter().map(Into::into).collect(); Ok(query_state.retrieved_documents(docs)) } } #[async_trait] impl Retrieve<SimilaritySingleEmbedding> for PgVector { async fn retrieve( &self, search_strategy: &SimilaritySingleEmbedding, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { Retrieve::<SimilaritySingleEmbedding<String>>::retrieve( self, &search_strategy.into_concrete_filter::<String>(), query, ) .await } } #[async_trait] impl Retrieve<CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>> for PgVector { async fn retrieve( &self, search_strategy: &CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { // Get the database pool let pool = self.get_pool().await?; // Build the custom query using both strategy and query state let mut query_builder = search_strategy.build_query(&query).await?; // Execute the query using the builder's built-in methods let results = query_builder .build_query_as::<VectorSearchResult>() // Convert to a typed query .fetch_all(pool) // Execute and get all results .await .map_err(|e| anyhow!("Failed to execute search query: {e}"))?; // Transform results into documents let documents = results.into_iter().map(Into::into).collect(); // Update query state with retrieved documents Ok(query.retrieved_documents(documents)) } } #[cfg(test)] mod tests { use 
crate::pgvector::fixtures::TestContext; use futures_util::TryStreamExt; use std::collections::HashSet; use swiftide_core::{Persist, indexing, indexing::EmbeddedField}; use swiftide_core::{ Retrieve, querying::{Query, search_strategies::SimilaritySingleEmbedding, states}, }; #[test_log::test(tokio::test)] async fn test_retrieve_multiple_docs_and_filter() { let test_context = TestContext::setup_with_cfg( vec!["filter"].into(), HashSet::from([EmbeddedField::Combined]), ) .await .expect("Test setup failed"); let nodes = vec![ indexing::TextNode::new("test_query1").with_metadata(("filter", "true")), indexing::TextNode::new("test_query2").with_metadata(("filter", "true")), indexing::TextNode::new("test_query3").with_metadata(("filter", "false")), ] .into_iter() .map(|node| { node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]); node.to_owned() }) .collect(); test_context .pgv_storage .batch_store(nodes) .await .try_collect::<Vec<_>>() .await .unwrap(); let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vec![1.0; 384]); let search_strategy = SimilaritySingleEmbedding::<()>::default(); let result = test_context .pgv_storage .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 3); let search_strategy = SimilaritySingleEmbedding::from_filter("filter = \"true\"".to_string()); let result = test_context .pgv_storage .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 2); let search_strategy = SimilaritySingleEmbedding::from_filter("filter = \"banana\"".to_string()); let result = test_context .pgv_storage .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 0); } #[test_log::test(tokio::test)] async fn test_retrieve_docs_with_metadata() { let test_context = TestContext::setup_with_cfg( vec!["other", "text"].into(), HashSet::from([EmbeddedField::Combined]), ) .await .expect("Test setup failed"); let 
nodes = vec![ indexing::TextNode::new("test_query1") .with_metadata([ ("other", serde_json::Value::from(10)), ("text", serde_json::Value::from("some text")), ]) .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]) .to_owned(), ]; test_context .pgv_storage .batch_store(nodes) .await .try_collect::<Vec<_>>() .await .unwrap(); let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vec![1.0; 384]); let search_strategy = SimilaritySingleEmbedding::<()>::default(); let result = test_context .pgv_storage .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 1); let doc = result.documents().first().unwrap(); assert_eq!( doc.metadata().get("other"), Some(&serde_json::Value::from(10)) ); assert_eq!( doc.metadata().get("text"), Some(&serde_json::Value::from("some text")) ); } } ================================================ FILE: swiftide-integrations/src/qdrant/indexing_node.rs ================================================ //! This module provides functionality to convert an `Node` into a `qdrant::PointStruct`. //! The conversion is essential for storing data in the Qdrant vector database, which is used //! for efficient vector similarity search. The module handles metadata augmentation and ensures //! data compatibility with Qdrant's required format. use anyhow::{Result, bail}; use std::{ collections::{HashMap, HashSet}, string::ToString, }; use qdrant_client::{ Payload, qdrant::{self, Value}, }; use swiftide_core::{Embedding, SparseEmbedding, indexing::EmbeddedField}; use super::NodeWithVectors; /// Implements the `TryInto` trait to convert an `NodeWithVectors` into a `qdrant::PointStruct`. /// This conversion is necessary for storing the node in the Qdrant vector database. impl TryInto<qdrant::PointStruct> for NodeWithVectors<'_> { type Error = anyhow::Error; /// Converts the `Node` into a `qdrant::PointStruct`. 
    ///
    /// # Errors
    ///
    /// Returns an error if the vector is not set in the `Node`.
    ///
    /// # Returns
    ///
    /// A `Result` which is `Ok` if the conversion is successful, containing the
    /// `qdrant::PointStruct`. If the conversion fails, it returns an `anyhow::Error`.
    fn try_into(self) -> Result<qdrant::PointStruct> {
        let node = self.node;

        // Calculate a unique identifier for the node.
        let id = node.id();

        // Extend the metadata with additional information.
        // TODO: The node is already cloned in the `NodeWithVectors` constructor.
        // Then additional data is added to the metadata, including the full chunk
        // Data is then taken as ref and reassigned. Seems like a lot of needless allocations

        // Create a payload compatible with Qdrant's API: every metadata entry
        // becomes a payload key, cloned into Qdrant `Value`s.
        let mut payload: Payload = node
            .metadata
            .iter()
            .map(|(k, v)| (k.clone(), Value::from(v.clone())))
            .collect::<HashMap<String, Value>>()
            .into();

        // Well-known payload keys added alongside the user metadata.
        payload.insert("path", node.path.to_string_lossy().to_string());
        payload.insert("content", node.chunk.clone());
        payload.insert(
            "last_updated_at",
            Value::from(chrono::Utc::now().to_rfc3339()),
        );

        let Some(vectors) = node.vectors.clone() else {
            bail!("Node without vectors")
        };
        let vectors =
            try_create_vectors(&self.vector_fields, vectors, node.sparse_vectors.clone())?;

        // Construct the `qdrant::PointStruct` and return it.
        Ok(qdrant::PointStruct::new(id.to_string(), vectors, payload))
    }
}

/// Builds the Qdrant vector payload for a node.
///
/// With exactly one dense vector and no sparse vectors, an unnamed (default)
/// vector is produced. Otherwise, named vectors are created per embedded
/// field; fields not present in `vector_fields` are skipped, and sparse
/// vectors are stored under a `<field>_sparse` name.
fn try_create_vectors(
    vector_fields: &HashSet<&EmbeddedField>,
    vectors: HashMap<EmbeddedField, Embedding>,
    sparse_vectors: Option<HashMap<EmbeddedField, SparseEmbedding>>,
) -> Result<qdrant::Vectors> {
    if vectors.is_empty() {
        bail!("Node with empty vectors")
    } else if vectors.len() == 1 && sparse_vectors.is_none() {
        // Single dense vector: store it unnamed, matching collections created
        // with a plain (non-named) vector config.
        let Some(vector) = vectors.into_values().next() else {
            bail!("Node has no vector entry")
        };

        return Ok(vector.into());
    }

    let mut qdrant_vectors = qdrant::NamedVectors::default();

    // Only store vectors for fields that are configured on the collection.
    for (field, vector) in vectors {
        if !vector_fields.contains(&field) {
            continue;
        }

        qdrant_vectors = qdrant_vectors.add_vector(field.to_string(), vector);
    }

    if let Some(sparse_vectors) = sparse_vectors {
        for (field, sparse_vector) in sparse_vectors {
            if !vector_fields.contains(&field) {
                continue;
            }

            qdrant_vectors = qdrant_vectors.add_vector(
                format!("{field}_sparse"),
                qdrant::Vector::new_sparse(
                    sparse_vector.indices.into_iter().collect::<Vec<_>>(),
                    sparse_vector.values,
                ),
            );
        }
    }

    Ok(qdrant_vectors.into())
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use qdrant_client::qdrant::PointStruct;
    use swiftide_core::indexing::{EmbeddedField, TextNode};
    use test_case::test_case;

    use crate::qdrant::indexing_node::NodeWithVectors;
    use pretty_assertions::assert_eq;

    // Deterministic id derived from the fixture node's content/path.
    static EXPECTED_UUID: &str = "d42d252d-671d-37ef-a157-8e85d0710610";

    #[test_case(
        TextNode::builder()
            .path("/path")
            .chunk("data")
            .vectors([(EmbeddedField::Chunk, vec![1.0])])
            .metadata([("m1", "mv1")])
            .embed_mode(swiftide_core::indexing::EmbedMode::SingleWithMetadata)
            .build().unwrap(),
        HashSet::from([EmbeddedField::Combined]),
        PointStruct::new(EXPECTED_UUID, vec![1.0], HashMap::from([
            ("content", "data".into()),
            ("path", "/path".into()),
            ("m1", "mv1".into())])
        );
        "Node with single vector creates struct with unnamed vector"
    )]
    #[test_case(
        TextNode::builder()
            .path("/path")
            .chunk("data")
            .vectors([
                (EmbeddedField::Chunk, vec![1.0]),
(EmbeddedField::Metadata("m1".into()), vec![2.0]) ]) .metadata([("m1", "mv1")]) .embed_mode(swiftide_core::indexing::EmbedMode::PerField) .build().unwrap(), HashSet::from([EmbeddedField::Chunk, EmbeddedField::Metadata("m1".into())]), PointStruct::new(EXPECTED_UUID, HashMap::from([ ("Chunk".to_string(), vec![1.0]), ("Metadata: m1".to_string(), vec![2.0]) ]), HashMap::from([ ("content", "data".into()), ("path", "/path".into()), ("m1", "mv1".into())]) ); "Node with multiple vectors creates struct with named vectors" )] #[test_case( TextNode::builder() .path("/path") .chunk("data") .vectors([ (EmbeddedField::Chunk, vec![1.0]), (EmbeddedField::Combined, vec![1.0]), (EmbeddedField::Metadata("m1".into()), vec![1.0]), (EmbeddedField::Metadata("m2".into()), vec![2.0]) ]) .metadata([("m1", "mv1"), ("m2", "mv2")]) .embed_mode(swiftide_core::indexing::EmbedMode::Both) .build().unwrap(), HashSet::from([EmbeddedField::Combined]), PointStruct::new(EXPECTED_UUID, HashMap::from([ ("Combined".to_string(), vec![1.0]), ]), HashMap::from([ ("content", "data".into()), ("path", "/path".into()), ("m1", "mv1".into()), ("m2", "mv2".into())]) ); "Storing only `Combined` vector. Skipping other vectors." 
)] #[allow(clippy::needless_pass_by_value)] fn try_into_point_struct_test( node: TextNode, vector_fields: HashSet<EmbeddedField>, mut expected_point: PointStruct, ) { let node = NodeWithVectors::new(&node, vector_fields.iter().collect()); let point: PointStruct = node.try_into().expect("Can create PointStruct"); // patch last_update_at field to avoid test failure because of time difference let last_updated_at_key = "last_updated_at"; let last_updated_at = point .payload .get(last_updated_at_key) .expect("Has autogenerated `last_updated_at` field."); expected_point .payload .insert(last_updated_at_key.into(), last_updated_at.clone()); assert_eq!(point.id, expected_point.id); assert_eq!(point.payload, expected_point.payload); assert_eq!(point.vectors, expected_point.vectors); } } ================================================ FILE: swiftide-integrations/src/qdrant/mod.rs ================================================ //! This module provides integration with the Qdrant vector database. //! It includes functionalities to interact with Qdrant, such as creating and managing vector //! collections, storing data, and ensuring proper indexing for efficient searches. //! //! Qdrant can be used both in `indexing::Pipeline` and `query::Pipeline` mod indexing_node; mod persist; mod retrieve; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use anyhow::{Context as _, Result, bail}; use derive_builder::Builder; pub use qdrant_client; use qdrant_client::qdrant::{self, SparseVectorParamsBuilder, SparseVectorsConfigBuilder}; use swiftide_core::indexing::{EmbeddedField, TextNode}; const DEFAULT_COLLECTION_NAME: &str = "swiftide"; const DEFAULT_QDRANT_URL: &str = "http://localhost:6334"; const DEFAULT_BATCH_SIZE: usize = 50; /// A struct representing a Qdrant client with configuration options. 
/// /// This struct is used to interact with the Qdrant vector database, providing methods to create and /// manage vector collections, store data, and ensure proper indexing for efficient searches. /// /// Can be cloned with relative low cost as the client is shared. #[derive(Builder, Clone)] #[builder( pattern = "owned", setter(strip_option), build_fn(error = "anyhow::Error") )] pub struct Qdrant { /// The Qdrant client used to interact with the Qdrant vector database. /// /// By default the client will be build from `QDRANT_URL` and option `QDRANT_API_KEY`. /// It will fall back to `http://localhost:6334` if `QDRANT_URL` is not set. #[builder(setter(into), default = "self.default_client()?")] #[allow(clippy::missing_fields_in_debug)] client: Arc<qdrant_client::Qdrant>, /// The name of the collection to be used in Qdrant. Defaults to "swiftide". #[builder(default = "DEFAULT_COLLECTION_NAME.to_string()")] #[builder(setter(into))] collection_name: String, /// The default size of the vectors to be stored in the collection. vector_size: u64, #[builder(default = "Distance::Cosine")] /// The default distance of the vectors to be stored in the collection vector_distance: Distance, /// The batch size for operations. Optional. #[builder(default = "Some(DEFAULT_BATCH_SIZE)")] batch_size: Option<usize>, #[builder(private, default = "Self::default_vectors()")] pub(crate) vectors: HashMap<EmbeddedField, VectorConfig>, #[builder(private, default)] pub(crate) sparse_vectors: HashMap<EmbeddedField, SparseVectorConfig>, } impl Qdrant { /// Returns a new `QdrantBuilder` for constructing a `Qdrant` instance. pub fn builder() -> QdrantBuilder { QdrantBuilder::default() } /// Tries to create a `QdrantBuilder` from a given URL. Will use the api key in `QDRANT_API_KEY` /// if present. /// /// Returns /// /// # Arguments /// /// * `url` - A string slice that holds the URL for the Qdrant client. 
/// /// # Returns /// /// A `Result` containing the `QdrantBuilder` if successful, or an error otherwise. /// /// # Errors /// /// Errors if client fails build pub fn try_from_url(url: impl AsRef<str>) -> Result<QdrantBuilder> { Ok(QdrantBuilder::default().client( qdrant_client::Qdrant::from_url(url.as_ref()) .api_key(std::env::var("QDRANT_API_KEY")) .build()?, )) } /// Creates an index in the Qdrant collection if it does not already exist. /// /// This method checks if the specified collection exists in Qdrant. If it does not exist, it /// creates a new collection with the specified vector size and cosine distance metric. /// /// # Returns /// /// A `Result` indicating success or failure. /// /// # Errors /// /// Errors if client fails build pub async fn create_index_if_not_exists(&self) -> Result<()> { let collection_name = &self.collection_name; tracing::debug!("Checking if collection {collection_name} exists"); if self.client.collection_exists(collection_name).await? { tracing::warn!( "Collection {collection_name} exists, skipping collection creation; if vector configurations have not changed, you can ignore this message" ); return Ok(()); } let vectors_config = self.create_vectors_config()?; tracing::debug!(?vectors_config, "Adding vectors config"); let mut collection = qdrant::CreateCollectionBuilder::new(collection_name).vectors_config(vectors_config); if let Some(sparse_vectors_config) = self.create_sparse_vectors_config() { tracing::debug!(?sparse_vectors_config, "Adding sparse vectors config"); collection = collection.sparse_vectors_config(sparse_vectors_config); } tracing::info!("Creating collection {collection_name}"); self.client.create_collection(collection).await?; Ok(()) } fn create_vectors_config(&self) -> Result<qdrant_client::qdrant::vectors_config::Config> { if self.vectors.is_empty() { bail!("No configured vectors"); } else if self.vectors.len() == 1 && self.sparse_vectors.is_empty() { let config = self .vectors .values() .next() .context("Has 
one vector config")?; let vector_params = self.create_vector_params(config); return Ok(qdrant::vectors_config::Config::Params(vector_params)); } let mut map = HashMap::<String, qdrant::VectorParams>::default(); for (embedded_field, config) in &self.vectors { let vector_name = embedded_field.to_string(); let vector_params = self.create_vector_params(config); map.insert(vector_name, vector_params); } Ok(qdrant::vectors_config::Config::ParamsMap( qdrant::VectorParamsMap { map }, )) } fn create_sparse_vectors_config(&self) -> Option<qdrant::SparseVectorConfig> { if self.sparse_vectors.is_empty() { return None; } let mut sparse_vectors_config = SparseVectorsConfigBuilder::default(); for embedded_field in self.sparse_vectors.keys() { let vector_name = format!("{embedded_field}_sparse"); let vector_params = SparseVectorParamsBuilder::default(); sparse_vectors_config.add_named_vector_params(vector_name, vector_params); } Some(sparse_vectors_config.into()) } fn create_vector_params(&self, config: &VectorConfig) -> qdrant::VectorParams { let size = config.vector_size.unwrap_or(self.vector_size); let distance = config.distance.unwrap_or(self.vector_distance); tracing::debug!( "Creating vector params: size={}, distance={:?}", size, distance ); qdrant::VectorParamsBuilder::new(size, distance).build() } /// Returns the inner client for custom operations pub fn client(&self) -> &Arc<qdrant_client::Qdrant> { &self.client } } impl QdrantBuilder { #[allow(clippy::unused_self)] fn default_client(&self) -> Result<Arc<qdrant_client::Qdrant>> { let client = qdrant_client::Qdrant::from_url( &std::env::var("QDRANT_URL").unwrap_or(DEFAULT_QDRANT_URL.to_string()), ) .api_key(std::env::var("QDRANT_API_KEY")) .build() .context("Could not build default qdrant client")?; Ok(Arc::new(client)) } /// Configures a dense vector on the collection /// /// When not configured Pipeline by default configures vector only for /// [`EmbeddedField::Combined`] Default config is enough when /// 
`indexing::Pipeline::with_embed_mode` is not set or when the value is set to /// [`swiftide_core::indexing::EmbedMode::SingleWithMetadata`]. #[must_use] pub fn with_vector(mut self, vector: impl Into<VectorConfig>) -> QdrantBuilder { if self.vectors.is_none() { self = self.vectors(HashMap::default()); } let vector = vector.into(); if let Some(vectors) = self.vectors.as_mut() && let Some(overridden_vector) = vectors.insert(vector.embedded_field.clone(), vector) { tracing::warn!( "Overriding named vector config: {}", overridden_vector.embedded_field ); } self } /// Configures a sparse vector on the collection #[must_use] pub fn with_sparse_vector(mut self, vector: impl Into<SparseVectorConfig>) -> QdrantBuilder { if self.sparse_vectors.is_none() { self = self.sparse_vectors(HashMap::default()); } let vector = vector.into(); if let Some(vectors) = self.sparse_vectors.as_mut() && let Some(overridden_vector) = vectors.insert(vector.embedded_field.clone(), vector) { tracing::warn!( "Overriding named vector config: {}", overridden_vector.embedded_field ); } self } fn default_vectors() -> HashMap<EmbeddedField, VectorConfig> { HashMap::from([(EmbeddedField::default(), VectorConfig::default())]) } } #[allow(clippy::missing_fields_in_debug)] impl std::fmt::Debug for Qdrant { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Qdrant") .field("collection_name", &self.collection_name) .field("vector_size", &self.vector_size) .field("batch_size", &self.batch_size) .finish() } } /// Vector config /// /// See also [`QdrantBuilder::with_vector`] #[derive(Clone, Builder, Default)] pub struct VectorConfig { /// A type of the embeddable of the stored vector. #[builder(default)] pub(super) embedded_field: EmbeddedField, /// A size of the vector to be stored in the collection. 
/// /// Overrides default set in [`QdrantBuilder::vector_size`] #[builder(setter(into, strip_option), default)] vector_size: Option<u64>, /// A distance of the vector to be stored in the collection. /// /// Overrides default set in [`QdrantBuilder::vector_distance`] #[builder(setter(into, strip_option), default)] distance: Option<qdrant::Distance>, } impl VectorConfig { pub fn builder() -> VectorConfigBuilder { VectorConfigBuilder::default() } } impl From<EmbeddedField> for VectorConfig { fn from(value: EmbeddedField) -> Self { Self { embedded_field: value, ..Default::default() } } } /// Sparse Vector config #[derive(Clone, Builder, Default)] pub struct SparseVectorConfig { embedded_field: EmbeddedField, } impl From<EmbeddedField> for SparseVectorConfig { fn from(value: EmbeddedField) -> Self { Self { embedded_field: value, } } } pub type Distance = qdrant::Distance; /// Utility struct combining `TextNode` with `EmbeddedField`s of configured _Qdrant_ vectors. struct NodeWithVectors<'a> { node: &'a TextNode, vector_fields: HashSet<&'a EmbeddedField>, } impl<'a> NodeWithVectors<'a> { pub fn new(node: &'a TextNode, vector_fields: HashSet<&'a EmbeddedField>) -> Self { Self { node, vector_fields, } } } ================================================ FILE: swiftide-integrations/src/qdrant/persist.rs ================================================ //! This module provides an implementation of the `Storage` trait for the `Qdrant` struct. //! It includes methods for setting up the storage, storing a single node, and storing a batch of //! nodes. This integration allows the Swiftide project to use Qdrant as a storage backend. use std::collections::HashSet; use swiftide_core::{ indexing::{EmbeddedField, IndexingStream, Persist, TextNode}, prelude::*, }; use qdrant_client::qdrant::UpsertPointsBuilder; use super::{NodeWithVectors, Qdrant}; #[async_trait] impl Persist for Qdrant { type Input = String; type Output = String; /// Returns the batch size for the Qdrant storage. 
/// /// # Returns /// /// An `Option<usize>` representing the batch size if set, otherwise `None`. fn batch_size(&self) -> Option<usize> { self.batch_size } /// Sets up the Qdrant storage by creating the necessary index if it does not exist. /// /// # Returns /// /// A `Result<()>` which is `Ok` if the setup is successful, otherwise an error. /// /// # Errors /// /// This function will return an error if the index creation fails. #[tracing::instrument(skip_all, err)] async fn setup(&self) -> Result<()> { tracing::debug!("Setting up Qdrant storage"); self.create_index_if_not_exists().await } /// Stores a single indexing node in the Qdrant storage. /// /// WARN: If running debug builds, the store is blocking and will impact performance /// /// # Parameters /// /// - `node`: The `TextNode` to be stored. /// /// # Returns /// /// A `Result<()>` which is `Ok` if the storage is successful, otherwise an error. /// /// # Errors /// /// This function will return an error if the node conversion or storage operation fails. #[tracing::instrument(skip_all, err, name = "storage.qdrant.store")] async fn store(&self, node: TextNode) -> Result<TextNode> { let node_with_vectors = NodeWithVectors::new(&node, self.vector_fields()); let point = node_with_vectors.try_into()?; tracing::debug!("Storing node"); self.client .upsert_points( UpsertPointsBuilder::new(self.collection_name.clone(), vec![point]) .wait(cfg!(debug_assertions)), ) .await?; Ok(node) } /// Stores a batch of indexing nodes in the Qdrant storage. /// /// # Parameters /// /// - `nodes`: A vector of `TextNode` to be stored. /// /// # Returns /// /// A `Result<()>` which is `Ok` if the storage is successful, otherwise an error. /// /// # Errors /// /// This function will return an error if any node conversion or storage operation fails. 
#[tracing::instrument(skip_all, name = "storage.qdrant.batch_store")] async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> { let points = nodes .iter() .map(|node| NodeWithVectors::new(node, self.vector_fields())) .map(NodeWithVectors::try_into) .collect::<Result<Vec<_>>>(); let Ok(points) = points else { return vec![Err(points.unwrap_err())].into(); }; tracing::debug!("Storing batch of {} nodes", points.len()); match self .client .upsert_points( UpsertPointsBuilder::new(self.collection_name.clone(), points) .wait(cfg!(debug_assertions)), ) .await { Ok(_) => IndexingStream::iter(nodes.into_iter().map(Ok)), Err(err) => vec![Err(err.into())].into(), } } } impl Qdrant { fn vector_fields(&self) -> HashSet<&EmbeddedField> { self.vectors.keys().collect::<HashSet<_>>() } } ================================================ FILE: swiftide-integrations/src/qdrant/retrieve.rs ================================================ use qdrant_client::qdrant::{self, PrefetchQueryBuilder, ScoredPoint, SearchPointsBuilder}; use swiftide_core::{ Retrieve, document::Document, indexing::{EmbeddedField, Metadata}, prelude::{Result, *}, querying::{ Query, search_strategies::{HybridSearch, SimilaritySingleEmbedding}, states, }, }; use super::Qdrant; /// Implement the `Retrieve` trait for `SimilaritySingleEmbedding` search strategy. /// /// Can be used in the query pipeline to retrieve documents from Qdrant. /// /// Supports filters via the `qdrant_client::qdrant::Filter` type. 
#[async_trait] impl Retrieve<SimilaritySingleEmbedding<qdrant::Filter>> for Qdrant { #[tracing::instrument] async fn retrieve( &self, search_strategy: &SimilaritySingleEmbedding<qdrant::Filter>, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { let Some(embedding) = &query.embedding else { anyhow::bail!("No embedding for query") }; let mut query_builder = SearchPointsBuilder::new( &self.collection_name, embedding.to_owned(), search_strategy.top_k(), ) .with_payload(true); if let Some(filter) = &search_strategy.filter() { query_builder = query_builder.filter(filter.to_owned()); } if self.vectors.len() > 1 || !self.sparse_vectors.is_empty() { // TODO: Make this configurable // It will break if there are multiple vectors and no combined vector query_builder = query_builder.vector_name(EmbeddedField::Combined.field_name()); } let result = self .client .search_points(query_builder.build()) .await .context("Failed to retrieve from qdrant")? .result; let documents = result .into_iter() .map(scored_point_into_document) .collect::<Result<Vec<_>>>()?; Ok(query.retrieved_documents(documents)) } } /// Ensures that the `SimilaritySingleEmbedding` search strategy can be used when no filter is set. #[async_trait] impl Retrieve<SimilaritySingleEmbedding> for Qdrant { async fn retrieve( &self, search_strategy: &SimilaritySingleEmbedding, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { Retrieve::<SimilaritySingleEmbedding<qdrant::Filter>>::retrieve( self, &search_strategy.into_concrete_filter::<qdrant::Filter>(), query, ) .await } } /// Implement the `Retrieve` trait for `HybridSearch` search strategy. /// /// Can be used in the query pipeline to retrieve documents from Qdrant. /// /// Expects both a dense and sparse embedding to be set on the query. 
#[async_trait] impl Retrieve<HybridSearch<qdrant::Filter>> for Qdrant { #[tracing::instrument] async fn retrieve( &self, search_strategy: &HybridSearch<qdrant::Filter>, query: Query<states::Pending>, ) -> Result<Query<states::Retrieved>> { let Some(dense) = &query.embedding else { anyhow::bail!("No embedding for query") }; let Some(sparse) = &query.sparse_embedding else { anyhow::bail!("No sparse embedding for query") }; let mut sparse_prefetch = PrefetchQueryBuilder::default() .query(qdrant::Query::new_nearest(qdrant::VectorInput::new_sparse( sparse.indices.clone(), sparse.values.clone(), ))) .using(search_strategy.sparse_vector_field().sparse_field_name()) .limit(search_strategy.top_n()); let mut dense_prefetch = PrefetchQueryBuilder::default() .query(qdrant::Query::new_nearest(dense.clone())) .using(search_strategy.dense_vector_field().field_name()) .limit(search_strategy.top_n()); if let Some(filter) = search_strategy.filter() { sparse_prefetch = sparse_prefetch.filter(filter.clone()); dense_prefetch = dense_prefetch.filter(filter.clone()); } let query_points = qdrant::QueryPointsBuilder::new(&self.collection_name) .with_payload(true) .add_prefetch(sparse_prefetch) .add_prefetch(dense_prefetch) .query(qdrant::Query::new_fusion(qdrant::Fusion::Rrf)) .limit(search_strategy.top_k()); // NOTE: Potential improvement to consume the vectors instead of cloning let result = self.client.query(query_points).await?.result; let documents = result .into_iter() .map(scored_point_into_document) .collect::<Result<Vec<_>>>()?; Ok(query.retrieved_documents(documents)) } } fn scored_point_into_document(scored_point: ScoredPoint) -> Result<Document> { let content = scored_point .payload .get("content") .context("Expected document in qdrant payload")? 
.to_string(); let metadata: Metadata = scored_point .payload .into_iter() .filter(|(k, _)| *k != "content") .collect::<Vec<(_, _)>>() .into(); Ok(Document::new(content, Some(metadata))) } #[cfg(test)] mod tests { use itertools::Itertools as _; use swiftide_core::{ Persist as _, indexing::{self, EmbeddedField}, }; use super::*; async fn setup() -> ( testcontainers::ContainerAsync<testcontainers::GenericImage>, Qdrant, ) { let (guard, qdrant_url) = swiftide_test_utils::start_qdrant().await; let qdrant_client = Qdrant::try_from_url(qdrant_url) .unwrap() .vector_size(384) .with_vector(EmbeddedField::Combined) .with_sparse_vector(EmbeddedField::Combined) .build() .unwrap(); qdrant_client.setup().await.unwrap(); let nodes = vec![ indexing::TextNode::new("test_query1").with_metadata(("filter", "true")), indexing::TextNode::new("test_query2").with_metadata(("filter", "true")), indexing::TextNode::new("test_query3").with_metadata(("filter", "false")), ] .into_iter() .map(|node| { node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]); node.with_sparse_vectors([( EmbeddedField::Combined, swiftide_core::SparseEmbedding { indices: vec![0, 1], values: vec![1.0, 1.0], }, )]); node.to_owned() }) .collect(); qdrant_client .batch_store(nodes) .await .try_collect::<Vec<_>>() .await .unwrap(); (guard, qdrant_client) } #[test_log::test(tokio::test)] async fn test_retrieve_multiple_docs_and_filter() { let (_guard, qdrant_client) = setup().await; let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vec![1.0; 384]); let search_strategy = SimilaritySingleEmbedding::<()>::default(); let result = qdrant_client .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 3); assert_eq!( result .documents() .iter() .sorted() .map(Document::content) .collect_vec(), // FIXME: The extra quotes should be removed by serde (via qdrant::Value), but they are // not ["\"test_query1\"", "\"test_query2\"", "\"test_query3\""] 
.into_iter() .sorted() .collect_vec() ); let search_strategy = SimilaritySingleEmbedding::from_filter(qdrant::Filter::must([ qdrant::Condition::matches("filter", "true".to_string()), ])); let result = qdrant_client .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 2); assert_eq!( result .documents() .iter() .sorted() .map(Document::content) .collect_vec(), ["\"test_query1\"", "\"test_query2\""] .into_iter() .sorted() .collect_vec() ); let search_strategy = SimilaritySingleEmbedding::from_filter(qdrant::Filter::must([ qdrant::Condition::matches("filter", "banana".to_string()), ])); let result = qdrant_client .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 0); } #[tokio::test] async fn test_hybrid_search() { let (_guard, qdrant_client) = setup().await; let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vec![1.0; 384]); query.sparse_embedding = Some(swiftide_core::SparseEmbedding { indices: vec![0, 1], values: vec![1.0, 1.0], }); let search_strategy = HybridSearch::default(); let result = qdrant_client .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 3); } #[tokio::test] async fn test_hybrid_search_with_filter() { let (_guard, qdrant_client) = setup().await; let mut query = Query::<states::Pending>::new("test_query"); query.embedding = Some(vec![1.0; 384]); query.sparse_embedding = Some(swiftide_core::SparseEmbedding { indices: vec![0, 1], values: vec![1.0, 1.0], }); let search_strategy = HybridSearch::from_filter(qdrant::Filter::must([qdrant::Condition::matches( "filter", "true".to_string(), )])); let result = qdrant_client .retrieve(&search_strategy, query.clone()) .await .unwrap(); assert_eq!(result.documents().len(), 2); } } ================================================ FILE: swiftide-integrations/src/redb/mod.rs ================================================ //! 
Redb is a simple, portable, high-performance, ACID, embedded key-value store. //! //! Redb can be used as a fast, embedded node cache, without the need for external services. use anyhow::Result; use std::{path::PathBuf, sync::Arc}; use derive_builder::Builder; mod node_cache; /// `Redb` provides a caching filter for indexing nodes using Redb. /// /// Redb is a simple, portable, high-performance, ACID, embedded key-value store. /// It enables using a local file based cache without the need for external services. /// /// # Example /// /// ```no_run /// # use swiftide_integrations::redb::{Redb}; /// Redb::builder() /// .database_path("/my/redb") /// .table_name("swiftide_test") /// .cache_key_prefix("my_cache") /// .build().unwrap(); /// ``` #[derive(Clone, Builder)] #[builder(build_fn(error = "anyhow::Error"), setter(into))] pub struct Redb { /// The database to use for caching nodes. Allows overwriting the default database created from /// `database_path`. #[builder(setter(into), default = "Arc::new(self.default_database()?)")] database: Arc<redb::Database>, /// Path to the database, required if no database override is provided. This is the recommended /// usage. #[builder(setter(into, strip_option))] database_path: Option<PathBuf>, /// The name of the table to use for caching nodes. Defaults to "swiftide". #[builder(default = "\"swiftide\".to_string()")] table_name: String, /// Prefix to be used for keys stored in the database to avoid collisions. Can be used to /// manually invalidate the cache. 
#[builder(default = "String::new()")] cache_key_prefix: String, } impl std::fmt::Debug for Redb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Redb") .field("database", &self.database) .field("database_path", &self.database_path) .field("table_name", &self.table_name) .field("cache_key_prefix", &self.cache_key_prefix) .finish() } } impl RedbBuilder { fn default_database(&self) -> Result<redb::Database> { let db = redb::Database::create( self.database_path .clone() .flatten() .ok_or(anyhow::anyhow!("Expected database path"))?, )?; Ok(db) } } impl Redb { pub fn builder() -> RedbBuilder { RedbBuilder::default() } pub fn node_key(&self, node: &swiftide_core::indexing::TextNode) -> String { format!("{}.{}", self.cache_key_prefix, node.id()) } pub fn table_definition(&self) -> redb::TableDefinition<'_, String, bool> { redb::TableDefinition::<String, bool>::new(&self.table_name) } pub fn database(&self) -> &redb::Database { &self.database } } ================================================ FILE: swiftide-integrations/src/redb/node_cache.rs ================================================ use anyhow::Result; use async_trait::async_trait; use redb::ReadableDatabase; use swiftide_core::{NodeCache, indexing::TextNode}; use super::Redb; // Simple proc macro that gets the ok value of a result or logs the error and returns false (not // cached) // // The underlying issue is that redb can be fickly if panics happened. We just want to make sure it // does not become worse. There probably is a better solution. macro_rules! 
unwrap_or_log { ($result:expr) => { match $result { Ok(value) => value, Err(e) => { tracing::error!("Error: {:#}", e); debug_assert!( true, "Redb should not give errors unless in very weird situations; this is a bug: {:#}", e ); return false; } } }; } #[async_trait] impl NodeCache for Redb { type Input = String; async fn get(&self, node: &TextNode) -> bool { let table_definition = self.table_definition(); let read_txn = unwrap_or_log!(self.database.begin_read()); let result = read_txn.open_table(table_definition); let table = match result { Ok(table) => table, Err(redb::TableError::TableDoesNotExist { .. }) => { // Create the table { let write_txn = unwrap_or_log!(self.database.begin_write()); unwrap_or_log!(write_txn.open_table(table_definition)); unwrap_or_log!(write_txn.commit()); } let read_tx = unwrap_or_log!(self.database.begin_read()); unwrap_or_log!(read_tx.open_table(table_definition)) } Err(e) => { tracing::error!("Failed to open table: {e:#}"); return false; } }; match table.get(self.node_key(node)).unwrap() { Some(access_guard) => access_guard.value(), None => false, } } async fn set(&self, node: &TextNode) { let write_txn = self.database.begin_write().unwrap(); { let mut table = write_txn.open_table(self.table_definition()).unwrap(); table.insert(self.node_key(node), true).unwrap(); } write_txn.commit().unwrap(); } /// Deletes the full cache table from the database. 
async fn clear(&self) -> Result<()> { let write_txn = self.database.begin_write().unwrap(); let _ = write_txn.delete_table(self.table_definition()); write_txn.commit().unwrap(); Ok(()) } } #[cfg(test)] mod tests { use super::*; use swiftide_core::indexing::TextNode; use temp_dir::TempDir; fn setup_redb() -> Redb { let tempdir = TempDir::new().unwrap(); Redb::builder() .database_path(tempdir.child("test_clear")) .build() .unwrap() } #[tokio::test] async fn test_get_set() { let redb = setup_redb(); let node = TextNode::new("test_get_set"); assert!(!redb.get(&node).await); redb.set(&node).await; assert!(redb.get(&node).await); } #[tokio::test] async fn test_clear() { let redb = setup_redb(); let node = TextNode::new("test_clear"); redb.set(&node).await; assert!(redb.get(&node).await); redb.clear().await.unwrap(); assert!(!redb.get(&node).await); } } ================================================ FILE: swiftide-integrations/src/redis/message_history.rs ================================================ use anyhow::{Context as _, Result}; use async_trait::async_trait; use swiftide_core::{MessageHistory, chat_completion::ChatMessage, indexing::Chunk}; use super::Redis; #[async_trait] impl<T: Chunk> MessageHistory for Redis<T> { async fn history(&self) -> Result<Vec<ChatMessage>> { if let Some(mut cm) = self.lazy_connect().await { let messages: Vec<String> = redis::cmd("LRANGE") .arg(&self.message_history_key) .arg(0) .arg(-1) .query_async(&mut cm) .await .context("Error fetching message history")?; messages .into_iter() .map(|msg| { serde_json::from_str::<ChatMessage>(&msg).context("Error deserializing message") }) .collect() } else { anyhow::bail!("Failed to connect to Redis") } } async fn push_owned(&self, item: ChatMessage) -> Result<()> { if let Some(mut cm) = self.lazy_connect().await { redis::cmd("RPUSH") .arg(&self.message_history_key) .arg(serde_json::to_string(&item)?) 
.query_async::<()>(&mut cm) .await .context("Error pushing to message history")?; Ok(()) } else { anyhow::bail!("Failed to connect to Redis") } } async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> { if let Some(mut cm) = self.lazy_connect().await { if items.is_empty() { return Ok(()); } redis::cmd("RPUSH") .arg(&self.message_history_key) .arg(serialize_messages(items)?) .query_async::<()>(&mut cm) .await .context("Error pushing to message history")?; Ok(()) } else { anyhow::bail!("Failed to connect to Redis") } } async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> { if let Some(mut cm) = self.lazy_connect().await { // If it does not exist yet, we can just push the items let _ = redis::cmd("DEL") .arg(&self.message_history_key) .query_async::<()>(&mut cm) .await; if items.is_empty() { // If we are overwriting with an empty history, we can just return return Ok(()); } redis::cmd("RPUSH") .arg(&self.message_history_key) .arg(serialize_messages(items)?) .query_async::<()>(&mut cm) .await .context("Error pushing to message history")?; Ok(()) } else { anyhow::bail!("Failed to connect to Redis") } } } fn serialize_messages(items: Vec<ChatMessage>) -> Result<Vec<String>> { items .into_iter() .map(|item| serde_json::to_string(&item).context("Error serializing message")) .collect() } #[cfg(test)] mod tests { use testcontainers::{ContainerAsync, GenericImage, runners::AsyncRunner as _}; use super::*; async fn start_redis() -> (String, ContainerAsync<GenericImage>) { let redis_container = testcontainers::GenericImage::new("redis", "7.2.4") .with_exposed_port(6379.into()) .with_wait_for(testcontainers::core::WaitFor::message_on_stdout( "Ready to accept connections", )) .start() .await .expect("Redis started"); let host = redis_container.get_host().await.unwrap(); let port = redis_container.get_host_port_ipv4(6379).await.unwrap(); let url = format!("redis://{host}:{port}/"); (url, redis_container) } #[tokio::test] async fn test_no_messages_yet() { 
let (url, _container) = start_redis().await; let redis = Redis::try_from_url(url, "tests").unwrap(); let messages = redis.history().await.unwrap(); assert!( messages.is_empty(), "Expected history to be empty for new Redis key" ); } #[tokio::test] async fn test_adding_and_next_completions() { let (url, _container) = start_redis().await; let redis = Redis::try_from_url(url, "tests").unwrap(); let m1 = ChatMessage::new_system("System test"); let m2 = ChatMessage::User("User test".into()); redis.push_owned(m1.clone()).await.unwrap(); redis.push_owned(m2.clone()).await.unwrap(); let hist = redis.history().await.unwrap(); assert_eq!( hist, vec![m1.clone(), m2.clone()], "History should match what's pushed" ); let hist2 = redis.history().await.unwrap(); assert_eq!( hist2, vec![m1, m2], "History should be unchanged on repeated call" ); } #[tokio::test] async fn test_overwrite_history() { let (url, _container) = start_redis().await; let redis = Redis::try_from_url(url, "tests").unwrap(); // Check that overwrite on empty also works redis.overwrite(vec![]).await.unwrap(); let m1 = ChatMessage::new_system("First"); let m2 = ChatMessage::User("Second".into()); redis.push_owned(m1.clone()).await.unwrap(); redis.push_owned(m2.clone()).await.unwrap(); let m3 = ChatMessage::new_assistant(Some("Overwritten".to_string()), None); redis.overwrite(vec![m3.clone()]).await.unwrap(); let hist = redis.history().await.unwrap(); assert_eq!( hist, vec![m3], "History should only contain the overwritten message" ); } #[tokio::test] async fn test_extend() { let (url, _container) = start_redis().await; let redis = Redis::try_from_url(url, "tests").unwrap(); let m1 = ChatMessage::new_system("First"); let m2 = ChatMessage::User("Second".into()); redis.push_owned(m1.clone()).await.unwrap(); let m3 = ChatMessage::new_assistant(Some("Third".to_string()), None); redis .extend_owned(vec![m2.clone(), m3.clone()]) .await .unwrap(); let hist = redis.history().await.unwrap(); assert_eq!(hist, vec![m1, m2, 
m3], "History should append on extend"); } } ================================================ FILE: swiftide-integrations/src/redis/mod.rs ================================================ //! This module provides the integration with Redis for caching nodes in the Swiftide system. //! //! The primary component of this module is the `Redis`, which is re-exported for use //! in other parts of the system. The `Redis` struct is responsible for managing and //! caching nodes during the indexing process, leveraging Redis for efficient storage and retrieval. //! //! # Overview //! //! Redis implements the following `Swiftide` traits: //! - `Node<T>Cache` //! - `Persist` //! - `MessageHistory` //! //! Additionally it provides various helper and utility functions for managing the Redis connection //! and key management. The connection is managed using a connection manager. When //! cloned, the connection manager is shared across all instances. use std::sync::Arc; use anyhow::{Context as _, Result}; use derive_builder::Builder; use serde::Serialize; use tokio::sync::RwLock; use swiftide_core::indexing::{Chunk, Node}; mod message_history; mod node_cache; mod persist; /// `Redis` provides a caching mechanism for nodes using Redis. /// It helps in optimizing the indexing process by skipping nodes that have already been processed. /// /// # Fields /// /// * `client` - The Redis client used to interact with the Redis server. /// * `connection_manager` - Manages the Redis connections asynchronously. /// * `key_prefix` - A prefix used for keys stored in Redis to avoid collisions. 
#[allow(clippy::type_complexity)]
#[derive(Builder, Clone)]
#[builder(pattern = "owned", setter(strip_option))]
pub struct Redis<T: Chunk = String> {
    // The Redis client used to interact with the server.
    #[builder(setter(into))]
    client: Arc<redis::Client>,
    // Lazily initialized connection manager, shared across clones (see
    // `lazy_connect`); skipped in the builder.
    #[builder(default, setter(skip))]
    connection_manager: Arc<RwLock<Option<redis::aio::ConnectionManager>>>,
    // Prefix for cache keys stored in Redis to avoid collisions.
    #[builder(default, setter(into))]
    cache_key_prefix: Arc<String>,
    #[builder(default = "10")]
    /// The batch size used for persisting nodes. Defaults to a safe 10.
    batch_size: usize,
    #[builder(default)]
    /// Customize the key used for persisting nodes
    persist_key_fn: Option<fn(&Node<T>) -> Result<String>>,
    #[builder(default)]
    /// Customize the value used for persisting nodes
    persist_value_fn: Option<fn(&Node<T>) -> Result<String>>,
    // Redis list key under which the chat message history is stored.
    #[builder(default = "message_history".to_string().into(), setter(into))]
    message_history_key: Arc<String>,
}
pub fn try_from_url(url: impl AsRef<str>, prefix: impl AsRef<str>) -> Result<Redis<String>> { let client = redis::Client::open(url.as_ref()).context("Failed to open redis client")?; Ok(Redis::<String> { client: client.into(), connection_manager: Arc::new(RwLock::new(None)), cache_key_prefix: prefix.as_ref().to_string().into(), batch_size: 10, persist_key_fn: None, persist_value_fn: None, message_history_key: format!("{}:message_history", prefix.as_ref()).into(), }) } } impl<T: Chunk> Redis<T> { /// # Errors /// /// Returns an error if the Redis client cannot be opened pub fn try_build_from_url(url: impl AsRef<str>) -> Result<RedisBuilder<T>> { Ok(RedisBuilder::default() .client(redis::Client::open(url.as_ref()).context("Failed to open redis client")?)) } /// Builds a new `Redis` instance from the builder. pub fn builder() -> RedisBuilder<T> { RedisBuilder::default() } /// Set the key to be used for the message history pub fn with_message_history_key(&mut self, prefix: impl Into<String>) -> &mut Self { self.message_history_key = Arc::new(prefix.into()); self } /// Lazily connects to the Redis server and returns the connection manager. /// /// # Returns /// /// An `Option` containing the `ConnectionManager` if the connection is successful, or `None` if /// it fails. /// /// # Errors /// /// Logs an error and returns `None` if the connection manager cannot be obtained. async fn lazy_connect(&self) -> Option<redis::aio::ConnectionManager> { if self.connection_manager.read().await.is_none() { let result = self.client.get_connection_manager().await; if let Err(e) = result { tracing::error!("Failed to get connection manager: {}", e); return None; } let mut cm = self.connection_manager.write().await; *cm = result.ok(); } self.connection_manager.read().await.clone() } /// Generates a Redis key for a given node using the key prefix and the node's hash. /// /// # Parameters /// /// * `node` - The node for which the key is to be generated. 
/// /// # Returns /// /// A `String` representing the Redis key for the node. fn cache_key_for_node(&self, node: &Node<T>) -> String { format!("{}:{}", self.cache_key_prefix, node.id()) } /// Generates a key for a given node to be persisted in Redis. fn persist_key_for_node(&self, node: &Node<T>) -> Result<String> { if let Some(key_fn) = self.persist_key_fn { key_fn(node) } else { let hash = node.id(); Ok(format!("{}:{}", node.path.to_string_lossy(), hash)) } } /// Resets the cache by deleting all keys with the specified prefix. /// This function is intended for testing purposes and is inefficient for production use. /// /// # Errors /// /// Panics if the keys cannot be retrieved or deleted. #[allow(dead_code)] async fn reset_cache(&self) { if let Some(mut cm) = self.lazy_connect().await { let keys: Vec<String> = redis::cmd("KEYS") .arg(format!("{}:*", self.cache_key_prefix)) .query_async(&mut cm) .await .expect("Could not get keys"); for key in &keys { let _: usize = redis::cmd("DEL") .arg(key) .query_async(&mut cm) .await .expect("Failed to reset cache"); } } } /// Gets a node persisted in Redis using the GET command /// Takes a node and returns a Result<Option<String>> #[allow(dead_code)] async fn get_node(&self, node: &Node<T>) -> Result<Option<String>> { if let Some(mut cm) = self.lazy_connect().await { let key = self.persist_key_for_node(node)?; let result: Option<String> = redis::cmd("GET") .arg(key) .query_async(&mut cm) .await .context("Error getting from redis")?; Ok(result) } else { anyhow::bail!("Failed to connect to Redis") } } } impl<T: Chunk + Serialize> Redis<T> { /// Generates a value for a given node to be persisted in Redis. /// By default, the node is serialized as JSON. /// If a custom function is provided, it is used to generate the value. /// Otherwise, the node is serialized as JSON. 
fn persist_value_for_node(&self, node: &Node<T>) -> Result<String> { if let Some(value_fn) = self.persist_value_fn { value_fn(node) } else { Ok(serde_json::to_string(node)?) } } } // Redis CM does not implement debug #[allow(clippy::missing_fields_in_debug)] impl<T: Chunk> std::fmt::Debug for Redis<T> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Redis") .field("client", &self.client) .finish() } } ================================================ FILE: swiftide-integrations/src/redis/node_cache.rs ================================================ use anyhow::Result; use async_trait::async_trait; use swiftide_core::indexing::{Chunk, Node, NodeCache}; use super::Redis; #[allow(dependency_on_unit_never_type_fallback)] #[async_trait] impl<T: Chunk> NodeCache for Redis<T> { type Input = T; /// Checks if a node is present in the cache. /// /// # Parameters /// /// * `node` - The node to be checked in the cache. /// /// # Returns /// /// `true` if the node is present in the cache, `false` otherwise. /// /// # Errors /// /// Logs an error and returns `false` if the cache check fails. #[tracing::instrument(skip_all, fields(hit), level = "trace")] async fn get(&self, node: &Node<T>) -> bool { let cache_result = if let Some(mut cm) = self.lazy_connect().await { let result = redis::cmd("EXISTS") .arg(self.cache_key_for_node(node)) .query_async(&mut cm) .await; match result { Ok(1) => true, Ok(0) => false, Err(e) => { tracing::error!("Failed to check node cache: {}", e); false } _ => { tracing::error!("Unexpected response from redis"); false } } } else { false }; tracing::Span::current().record("hit", cache_result); cache_result } /// Sets a node in the cache. /// /// # Parameters /// /// * `node` - The node to be set in the cache. /// /// # Errors /// /// Logs an error if the node cannot be set in the cache. 
#[tracing::instrument(skip_all, level = "trace")] async fn set(&self, node: &Node<T>) { if let Some(mut cm) = self.lazy_connect().await { let result: Result<(), redis::RedisError> = redis::cmd("SET") .arg(self.cache_key_for_node(node)) .arg(1) .query_async(&mut cm) .await; if let Err(e) = result { tracing::error!("Failed to set node cache: {}", e); } } } async fn clear(&self) -> Result<()> { if self.cache_key_prefix.is_empty() { return Err(anyhow::anyhow!( "No cache key prefix set; not flushing cache" )); } if let Some(mut cm) = self.lazy_connect().await { redis::cmd("DEL") .arg(format!("{}*", self.cache_key_prefix)) .query_async::<()>(&mut cm) .await?; Ok(()) } else { anyhow::bail!("Failed to connect to Redis"); } } } #[cfg(test)] mod tests { use super::*; use swiftide_core::indexing::TextNode; use testcontainers::runners::AsyncRunner; /// Tests the `RedisNodeCache` implementation. #[test_log::test(tokio::test)] async fn test_redis_cache() { let redis = testcontainers::GenericImage::new("redis", "7.2.4") .with_exposed_port(6379.into()) .with_wait_for(testcontainers::core::WaitFor::message_on_stdout( "Ready to accept connections", )) .start() .await .expect("Redis started"); let host = redis.get_host().await.unwrap(); let port = redis.get_host_port_ipv4(6379).await.unwrap(); let cache = Redis::try_from_url(format!("redis://{host}:{port}"), "test") .expect("Could not build redis client"); cache.reset_cache().await; let node = TextNode::new("chunk"); let before_cache = cache.get(&node).await; assert!(!before_cache); cache.set(&node).await; let after_cache = cache.get(&node).await; assert!(after_cache); } } ================================================ FILE: swiftide-integrations/src/redis/persist.rs ================================================ use anyhow::{Context as _, Result}; use async_trait::async_trait; use serde::Serialize; use swiftide_core::{ Persist, indexing::{Chunk, IndexingStream, Node}, }; use super::Redis; #[async_trait] 
#[allow(dependency_on_unit_never_type_fallback)] impl<T: Chunk + Serialize> Persist for Redis<T> { type Input = T; type Output = T; async fn setup(&self) -> Result<()> { Ok(()) } fn batch_size(&self) -> Option<usize> { Some(self.batch_size) } /// Stores a node in Redis using the SET command. /// /// By default nodes are stored with the path and hash as key and the node serialized as JSON as /// value. /// /// You can customize the key and value used for storing nodes by setting the `persist_key_fn` /// and `persist_value_fn` fields. async fn store(&self, node: Node<T>) -> Result<Node<T>> { if let Some(mut cm) = self.lazy_connect().await { redis::cmd("SET") .arg(self.persist_key_for_node(&node)?) .arg(self.persist_value_for_node(&node)?) .query_async::<()>(&mut cm) .await .context("Error persisting to redis")?; Ok(node) } else { anyhow::bail!("Failed to connect to Redis") } } /// Stores a batch of nodes in Redis using the MSET command. /// /// By default nodes are stored with the path and hash as key and the node serialized as JSON as /// value. /// /// You can customize the key and value used for storing nodes by setting the `persist_key_fn` /// and `persist_value_fn` fields. 
async fn batch_store(&self, nodes: Vec<Node<T>>) -> IndexingStream<T> { // use mset for batch store if let Some(mut cm) = self.lazy_connect().await { let args = match nodes .iter() .map(|node| -> Result<Vec<String>> { let key = self.persist_key_for_node(node)?; let value = self.persist_value_for_node(node)?; Ok(vec![key, value]) }) .collect::<Result<Vec<_>>>() { Ok(args) => args, Err(err) => return vec![Err(err)].into(), }; let result: Result<()> = redis::cmd("MSET") .arg(args) .query_async(&mut cm) .await .context("Error persisting to redis"); if let Err(e) = result { IndexingStream::iter([Err(e)]) } else { IndexingStream::iter(nodes.into_iter().map(Ok)) } } else { IndexingStream::iter([Err(anyhow::anyhow!("Failed to connect to Redis"))]) } } } #[cfg(test)] mod tests { use super::*; use futures_util::TryStreamExt; use swiftide_core::indexing::TextNode; use testcontainers::{ContainerAsync, GenericImage, runners::AsyncRunner}; async fn start_redis() -> ContainerAsync<GenericImage> { testcontainers::GenericImage::new("redis", "7.2.4") .with_exposed_port(6379.into()) .with_wait_for(testcontainers::core::WaitFor::message_on_stdout( "Ready to accept connections", )) .start() .await .expect("Redis started") } #[test_log::test(tokio::test)] async fn test_redis_persist() { let redis_container = start_redis().await; let host = redis_container.get_host().await.unwrap(); let port = redis_container.get_host_port_ipv4(6379).await.unwrap(); let redis = Redis::try_build_from_url(format!("redis://{host}:{port}")) .unwrap() .build() .unwrap(); let node = TextNode::new("chunk"); redis.store(node.clone()).await.unwrap(); let stored_node = serde_json::from_str(&redis.get_node(&node).await.unwrap().unwrap()); assert_eq!(node, stored_node.unwrap()); } // test batch store #[test_log::test(tokio::test)] async fn test_redis_batch_persist() { let redis_container = start_redis().await; let host = redis_container.get_host().await.unwrap(); let port = 
redis_container.get_host_port_ipv4(6379).await.unwrap(); let redis = Redis::try_build_from_url(format!("redis://{host}:{port}")) .unwrap() .batch_size(20) .build() .unwrap(); let nodes = vec![TextNode::new("test"), TextNode::new("other")]; let stream = redis.batch_store(nodes).await; let streamed_nodes: Vec<TextNode> = stream.try_collect().await.unwrap(); assert_eq!(streamed_nodes.len(), 2); for node in streamed_nodes { let stored_node = serde_json::from_str(&redis.get_node(&node).await.unwrap().unwrap()); assert_eq!(node, stored_node.unwrap()); } } #[test_log::test(tokio::test)] async fn test_redis_custom_persist() { let redis_container = start_redis().await; let host = redis_container.get_host().await.unwrap(); let port = redis_container.get_host_port_ipv4(6379).await.unwrap(); let redis = Redis::<String>::try_build_from_url(format!("redis://{host}:{port}")) .unwrap() .persist_key_fn(|_node| Ok("test".to_string())) .persist_value_fn(|_node| Ok("hello world".to_string())) .build() .unwrap(); let node = Node::default(); redis.store(node.clone()).await.unwrap(); let stored_node = redis.get_node(&node).await.unwrap(); assert_eq!(stored_node.unwrap(), "hello world"); assert_eq!( redis.persist_key_for_node(&node).unwrap(), "test".to_string() ); } } ================================================ FILE: swiftide-integrations/src/scraping/html_to_markdown_transformer.rs ================================================ use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; use htmd::HtmlToMarkdown; use swiftide_core::{Transformer, indexing::TextNode}; /// Transforms HTML content into markdown. /// /// Useful for converting scraping results into markdown. #[swiftide_macros::indexing_transformer(derive(skip_default, skip_debug))] pub struct HtmlToMarkdownTransformer { /// The `HtmlToMarkdown` instance used to convert HTML to markdown. /// /// Sets a sane default, but can be customized. 
htmd: Arc<HtmlToMarkdown>,
}

impl Default for HtmlToMarkdownTransformer {
    fn default() -> Self {
        Self {
            // Default converter drops <script> and <style> tags entirely so
            // their contents never leak into the markdown output.
            htmd: HtmlToMarkdown::builder()
                .skip_tags(vec!["script", "style"])
                .build()
                .into(),
            concurrency: None,
            client: None,
            indexing_defaults: None,
        }
    }
}

// Manual Debug impl: only the type name is printed. Presumably the inner
// `HtmlToMarkdown` converter does not implement `Debug` — TODO confirm.
impl std::fmt::Debug for HtmlToMarkdownTransformer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HtmlToMarkdownTransformer").finish()
    }
}

#[async_trait]
impl Transformer for HtmlToMarkdownTransformer {
    type Input = String;
    type Output = String;

    /// Converts the HTML content in the `TextNode` to markdown.
    ///
    /// Will Err the node if the conversion fails.
    #[tracing::instrument(skip_all, name = "transformer.html_to_markdown")]
    async fn transform_node(&self, node: TextNode) -> Result<TextNode> {
        let chunk = self.htmd.convert(&node.chunk)?;
        // Rebuild from the original node so all other fields (path, metadata,
        // ...) carry over; only the chunk is replaced with the markdown.
        TextNode::build_from_other(&node).chunk(chunk).build()
    }

    fn concurrency(&self) -> Option<usize> {
        self.concurrency
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn test_html_to_markdown() {
        let node = TextNode::new("<h1>Hello, World!</h1>");
        let transformer = HtmlToMarkdownTransformer::default();
        let transformed_node = transformer.transform_node(node).await.unwrap();
        assert_eq!(transformed_node.chunk, "# Hello, World!");
    }
}

================================================ FILE: swiftide-integrations/src/scraping/loader.rs ================================================
use derive_builder::Builder;
use spider::website::Website;
use swiftide_core::{
    Loader,
    indexing::{IndexingStream, TextNode},
};

#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
/// Scrapes a given website
///
/// Under the hood uses the `spider` crate to scrape the website.
/// For more configuration options see their documentation.
pub struct ScrapingLoader { spider_website: Website, } impl ScrapingLoader { pub fn builder() -> ScrapingLoaderBuilder { ScrapingLoaderBuilder::default() } // Constructs a scrapingloader from a `spider::Website` configuration #[allow(dead_code)] pub fn from_spider(spider_website: Website) -> Self { Self { spider_website } } /// Constructs a scrapingloader from a given url pub fn from_url(url: impl AsRef<str>) -> Self { Self::from_spider(Website::new(url.as_ref())) } } impl Loader for ScrapingLoader { type Output = String; fn into_stream(mut self) -> IndexingStream<String> { let (tx, rx) = tokio::sync::mpsc::channel(1000); let mut spider_rx = self .spider_website .subscribe(0) .expect("Failed to subscribe to spider"); tracing::info!("Subscribed to spider"); let _recv_thread = tokio::spawn(async move { while let Ok(res) = spider_rx.recv().await { let html = res.get_html(); let original_size = html.len(); let node = TextNode::builder() .chunk(html) .original_size(original_size) .path(res.get_url()) .build(); tracing::debug!(?node, "[Spider] Received node from spider"); if let Err(error) = tx.send(node).await { tracing::error!(?error, "[Spider] Failed to send node to stream"); break; } } }); let mut spider_website = self.spider_website; let _scrape_thread = tokio::spawn(async move { tracing::info!("[Spider] Starting scrape loop"); // TODO: It would be much nicer if this used `scrape` instead, as it is supposedly // more concurrent spider_website.crawl().await; tracing::info!("[Spider] Scrape loop finished"); }); // NOTE: Handles should stay alive because of rx, but feels a bit fishy rx.into() } fn into_stream_boxed(self: Box<Self>) -> IndexingStream<String> { self.into_stream() } } #[cfg(test)] mod tests { use super::*; use anyhow::Result; use futures_util::StreamExt; use swiftide_core::indexing::Loader; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, Request, ResponseTemplate}; #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn 
test_scraping_loader_with_wiremock() { // Set up the wiremock server to simulate the remote web server let mock_server = MockServer::start().await; // Mocked response for the page we will scrape let body = "<html><body><h1>Test Page</h1></body></html>"; Mock::given(method("GET")) .and(path("/")) .respond_with(ResponseTemplate::new(200).set_body_string(body)) .mount(&mock_server) .await; // Create an instance of ScrapingLoader using the mock server's URL let loader = ScrapingLoader::from_url(mock_server.uri()); // Execute the into_stream method let stream = loader.into_stream(); // Process the stream to check if we get the expected result let nodes = stream.collect::<Vec<Result<TextNode>>>().await; assert_eq!(nodes.len(), 1); let first_node = nodes.first().unwrap().as_ref().unwrap(); assert_eq!(first_node.chunk, body); } #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn test_scraping_loader_multiple_pages() { // Set up the wiremock server to simulate the remote web server let mock_server = MockServer::start().await; // Mocked response for the page we will scrape let body = "<html><body><h1>Test Page</h1><a href=\"/other\">link</a></body></html>"; Mock::given(method("GET")) .and(path("/")) .respond_with(ResponseTemplate::new(200).set_body_string(body)) .mount(&mock_server) .await; let body2 = "<html><body><h1>Test Page 2</h1></body></html>"; Mock::given(method("GET")) .and(path("/other")) .respond_with(move |_req: &Request| { std::thread::sleep(std::time::Duration::from_secs(1)); ResponseTemplate::new(200).set_body_string(body2) }) .mount(&mock_server) .await; // Create an instance of ScrapingLoader using the mock server's URL let loader = ScrapingLoader::from_url(mock_server.uri()); // Execute the into_stream method let stream = loader.into_stream(); // Process the stream to check if we get the expected result let mut nodes = stream.collect::<Vec<Result<TextNode>>>().await; assert_eq!(nodes.len(), 2); let first_node = nodes.pop().unwrap().unwrap(); 
assert_eq!(first_node.chunk, body2); let second_node = nodes.pop().unwrap().unwrap(); assert_eq!(second_node.chunk, body); } } ================================================ FILE: swiftide-integrations/src/scraping/mod.rs ================================================ //! Scraping loader using and html to markdown transformer mod html_to_markdown_transformer; mod loader; pub use html_to_markdown_transformer::HtmlToMarkdownTransformer; pub use loader::ScrapingLoader; ================================================ FILE: swiftide-integrations/src/tiktoken/mod.rs ================================================ //! Use tiktoken-rs to estimate token count on various common Swiftide types //! //! Intended to be used for openai models. //! //! Note that the library is heavy on the unwraps. use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; use swiftide_core::token_estimation::{Estimatable, EstimateTokens}; use tiktoken_rs::{CoreBPE, get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer}; /// A tiktoken based tokenizer for openai models. Can also be used for other models. /// /// Implements `EstimateTokens` for various swiftide types (prompts, chat messages, lists of chat /// messages) and regular strings. /// /// Estimates are estimates; not exact counts. 
/// /// # Example /// /// ```no_run /// # use swiftide_core::token_estimation::EstimateTokens; /// # use swiftide_integrations::tiktoken::TikToken; /// /// # async fn test() { /// let tokenizer = TikToken::try_from_model("gpt-4-0314").unwrap(); /// let estimate = tokenizer.estimate("hello {{world}}").await.unwrap(); /// /// assert_eq!(estimate, 4); /// # } /// ``` #[derive(Clone)] pub struct TikToken { /// The tiktoken model to use bpe: Arc<CoreBPE>, } impl std::fmt::Debug for TikToken { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("TikToken").finish() } } impl Default for TikToken { fn default() -> Self { Self::try_from_model("gpt-4o") .expect("infallible; gpt-4o should be valid model for tiktoken") } } impl TikToken { /// Build a `TikToken` from an openai model name /// /// # Errors /// /// Errors if the tokenizer cannot be found from the model or it cannot be build pub fn try_from_model(model: impl AsRef<str>) -> Result<Self> { let bpe = get_bpe_from_model(model.as_ref())?; Ok(Self { bpe: Arc::new(bpe) }) } /// Build a `TikToken` from a `tiktoken_rs::tiktoken::Tokenizer` /// /// # Errors /// /// Errors if the tokenizer cannot be build pub fn try_from_tokenizer(tokenizer: Tokenizer) -> Result<Self> { let bpe = get_bpe_from_tokenizer(tokenizer)?; Ok(Self { bpe: Arc::new(bpe) }) } } #[async_trait] impl EstimateTokens for TikToken { async fn estimate(&self, value: impl Estimatable) -> Result<usize> { let mut total = 0; for text in value.for_estimate()? 
{ total += self.bpe.encode_with_special_tokens(text.as_ref()).len(); } Ok(total + value.additional_tokens()) } } #[cfg(test)] mod tests { use swiftide_core::{chat_completion::ChatMessage, prompt::Prompt}; use super::*; #[tokio::test] async fn test_estimate_tokens() { let tokenizer = TikToken::try_from_model("gpt-4-0314").unwrap(); let prompt = Prompt::from("hello {{world}}"); let tokens = tokenizer.estimate(&prompt).await.unwrap(); assert_eq!(tokens, 4); } #[tokio::test] async fn test_estimate_tokens_from_tokenizer() { let tokenizer = TikToken::try_from_tokenizer(Tokenizer::O200kBase).unwrap(); let prompt = "hello {{world}}"; let tokens = tokenizer.estimate(prompt).await.unwrap(); assert_eq!(tokens, 4); } #[tokio::test] async fn test_estimate_chat_messages() { let messages = vec![ ChatMessage::new_user("hello ".repeat(10)), ChatMessage::new_system("world"), ]; // 11x hello + 1x world + 2x 4 per message + 1x 3 for full + 2 whatever = 23 let tokenizer = TikToken::try_from_model("gpt-4-0314").unwrap(); dbg!(messages.as_slice().for_estimate().unwrap()); assert_eq!(tokenizer.estimate(messages.as_slice()).await.unwrap(), 23); } } ================================================ FILE: swiftide-integrations/src/treesitter/chunk_code.rs ================================================ //! Chunk code using tree-sitter use anyhow::{Context as _, Result}; use async_trait::async_trait; use derive_builder::Builder; use crate::treesitter::{ChunkSize, CodeSplitter, SupportedLanguages}; use swiftide_core::{ ChunkerTransformer, indexing::{IndexingStream, TextNode}, }; /// The `ChunkCode` struct is responsible for chunking code into smaller pieces /// based on the specified language and chunk size. /// /// It uses tree-sitter under the hood, and tries to split the code into smaller, meaningful /// chunks. 
/// /// # Example /// /// ```no_run /// # use swiftide_integrations::treesitter::transformers::ChunkCode; /// # use swiftide_integrations::treesitter::SupportedLanguages; /// // Chunk rust code with a maximum chunk size of 1000 bytes. /// ChunkCode::try_for_language_and_chunk_size(SupportedLanguages::Rust, 1000); /// /// // Chunk python code with a minimum chunk size of 500 bytes and maximum chunk size of 2048. /// // Smaller chunks than 500 bytes will be discarded. /// ChunkCode::try_for_language_and_chunk_size(SupportedLanguages::Python, 500..2048); /// ```` #[derive(Debug, Clone, Builder)] #[builder(pattern = "owned", setter(into, strip_option))] pub struct ChunkCode { chunker: CodeSplitter, #[builder(default)] concurrency: Option<usize>, } impl ChunkCode { pub fn builder() -> ChunkCodeBuilder { ChunkCodeBuilder::default() } /// Tries to create a `ChunkCode` instance for a given programming language. /// /// # Parameters /// - `lang`: The programming language to be used for chunking. It should implement /// `TryInto<SupportedLanguages>`. /// /// # Returns /// - `Result<Self>`: Returns an instance of `ChunkCode` if successful, otherwise returns an /// error. /// /// # Errors /// - Returns an error if the language is not supported or if the `CodeSplitter` fails to build. pub fn try_for_language(lang: impl TryInto<SupportedLanguages>) -> Result<Self> { Ok(Self { chunker: CodeSplitter::builder().try_language(lang)?.build()?, concurrency: None, }) } /// Tries to create a `ChunkCode` instance for a given programming language and chunk size. /// /// # Parameters /// - `lang`: The programming language to be used for chunking. It should implement /// `TryInto<SupportedLanguages>`. /// - `chunk_size`: The size of the chunks. It should implement `Into<ChunkSize>`. /// /// # Returns /// - `Result<Self>`: Returns an instance of `ChunkCode` if successful, otherwise returns an /// error. 
///
/// # Errors
/// - Returns an error if the language is not supported, if the chunk size is invalid, or if the
///   `CodeSplitter` fails to build.
pub fn try_for_language_and_chunk_size(
    lang: impl TryInto<SupportedLanguages>,
    chunk_size: impl Into<ChunkSize>,
) -> Result<Self> {
    Ok(Self {
        chunker: CodeSplitter::builder()
            .try_language(lang)?
            .chunk_size(chunk_size)
            .build()?,
        concurrency: None,
    })
}

/// Sets the maximum number of nodes this transformer processes concurrently.
#[must_use]
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
    self.concurrency = Some(concurrency);
    self
}
}

#[async_trait]
impl ChunkerTransformer for ChunkCode {
    type Input = String;
    type Output = String;

    /// Transforms a `TextNode` by splitting its code chunk into smaller pieces.
    ///
    /// # Parameters
    /// - `node`: The `TextNode` containing the code chunk to be split.
    ///
    /// # Returns
    /// - `IndexingStream`: A stream of `TextNode` instances, each containing a smaller chunk of
    ///   code.
    ///
    /// # Errors
    /// - If the code splitting fails, an error is sent downstream.
    #[tracing::instrument(skip_all, name = "transformers.chunk_code")]
    async fn transform_node(&self, node: TextNode) -> IndexingStream<String> {
        let split_result = self.chunker.split(&node.chunk);

        if let Ok(split) = split_result {
            // Running offset: each emitted node records where its chunk
            // starts within the parent text, advancing by `chunk.len()`
            // (a byte count) per chunk. The closure captures `offset` by
            // move and mutates it lazily as the stream is consumed.
            let mut offset = 0;

            IndexingStream::iter(split.into_iter().map(move |chunk| {
                let chunk_size = chunk.len();
                let node = TextNode::build_from_other(&node)
                    .chunk(chunk)
                    .offset(offset)
                    .build();
                offset += chunk_size;
                node
            }))
        } else {
            // Send the error downstream
            IndexingStream::iter(vec![Err(split_result
                .with_context(|| format!("Failed to chunk {}", node.path.display()))
                .unwrap_err())])
        }
    }

    fn concurrency(&self) -> Option<usize> {
        self.concurrency
    }
}

================================================ FILE: swiftide-integrations/src/treesitter/code_tree.rs ================================================
//! Code parsing
//!
//! Extracts typed semantics from code.
#![allow(dead_code)]
use itertools::Itertools;
use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator as _, Tree};

use anyhow::{Context as _, Result};
use std::collections::HashSet;

use crate::treesitter::queries::{
    csharp, go, java, javascript, python, ruby, rust, solidity, typescript,
};

use super::SupportedLanguages;

/// Parses source code of a single [`SupportedLanguages`] into a queryable [`CodeTree`].
#[derive(Debug)]
pub struct CodeParser {
    language: SupportedLanguages,
}

impl CodeParser {
    /// Creates a parser bound to the given language.
    pub fn from_language(language: SupportedLanguages) -> Self {
        Self { language }
    }

    /// Parses code and returns a `CodeTree`
    ///
    /// Tree-sitter is pretty lenient and will parse invalid code. I.e. if the code is invalid,
    /// queries might fail and return no results.
    ///
    /// This is good as it makes this safe to use for chunked code as well.
    ///
    /// # Errors
    ///
    /// Errors if the language is not support or if the tree cannot be parsed
    pub fn parse<'a>(&self, code: &'a str) -> Result<CodeTree<'a>> {
        let mut parser = Parser::new();
        parser.set_language(&self.language.into())?;
        let ts_tree = parser.parse(code, None).context("No nodes found")?;

        Ok(CodeTree {
            ts_tree,
            code,
            language: self.language,
        })
    }
}

/// A code tree is a queryable representation of code
pub struct CodeTree<'a> {
    // Parsed tree-sitter tree over `code`.
    ts_tree: Tree,
    // Borrowed source text; the lifetime ties the tree's byte offsets to the
    // original string without copying it.
    code: &'a str,
    language: SupportedLanguages,
}

/// Result of querying a [`CodeTree`] for symbols.
pub struct ReferencesAndDefinitions {
    // Symbols referenced by the code but not defined in it.
    pub references: Vec<String>,
    // Symbols defined within the code itself.
    pub definitions: Vec<String>,
}

impl CodeTree<'_> {
    /// Queries for references and definitions in the code. It returns a unique list of non-local
    /// references, and local definitions.
/// /// # Errors /// /// Errors if the query is invalid or fails pub fn references_and_definitions(&self) -> Result<ReferencesAndDefinitions> { let (defs, refs) = ts_queries_for_language(self.language); let defs_query = Query::new(&self.language.into(), defs)?; let refs_query = Query::new(&self.language.into(), refs)?; let defs = self.ts_query_for_matches(&defs_query)?; let refs = self.ts_query_for_matches(&refs_query)?; Ok(ReferencesAndDefinitions { // Remove any self references references: refs .into_iter() .filter(|r| !defs.contains(r)) .sorted() .collect(), definitions: defs.into_iter().sorted().collect(), }) } /// Given a `tree-sitter` query, searches the code and returns a list of matching symbols fn ts_query_for_matches(&self, query: &Query) -> Result<HashSet<String>> { let mut cursor = QueryCursor::new(); cursor .matches(query, self.ts_tree.root_node(), self.code.as_bytes()) .map_deref(|m| { m.captures .iter() .map(|c| { Ok(c.node .utf8_text(self.code.as_bytes()) .context("Failed to parse node")? 
.to_string()) }) .collect::<Result<Vec<_>>>() .map(|s| s.join("")) }) .collect::<Result<HashSet<_>>>() } } fn ts_queries_for_language(language: SupportedLanguages) -> (&'static str, &'static str) { use SupportedLanguages::{ C, CSharp, Cpp, Elixir, Go, HTML, Java, Javascript, PHP, Python, Ruby, Rust, Solidity, Typescript, }; match language { Rust => (rust::DEFS, rust::REFS), Python => (python::DEFS, python::REFS), // The univocal proof that TS is just a linter Typescript => (typescript::DEFS, typescript::REFS), Javascript => (javascript::DEFS, javascript::REFS), Ruby => (ruby::DEFS, ruby::REFS), Java => (java::DEFS, java::REFS), Go => (go::DEFS, go::REFS), CSharp => (csharp::DEFS, csharp::REFS), Solidity => (solidity::DEFS, solidity::REFS), C | Cpp | Elixir | PHP | HTML => unimplemented!(), } } #[cfg(test)] mod tests { use super::*; #[test] fn test_parsing_on_rust() { let parser = CodeParser::from_language(SupportedLanguages::Rust); let code = r#" use std::io; fn main() { println!("Hello, world!"); } "#; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!(result.references, vec!["println"]); assert_eq!(result.definitions, vec!["main"]); } #[test] fn test_parsing_on_solidity() { let parser = CodeParser::from_language(SupportedLanguages::Solidity); let code = r" pragma solidity ^0.8.0; contract MyContract { function myFunction() public { emit MyEvent(); } } "; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!(result.references, vec!["MyEvent"]); assert_eq!(result.definitions, vec!["MyContract", "myFunction"]); } #[test] fn test_parsing_on_ruby() { let parser = CodeParser::from_language(SupportedLanguages::Ruby); let code = r#" class A < Inheritance include ActuallyAlsoInheritance def a puts "A" end end "#; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!( result.references, 
["ActuallyAlsoInheritance", "Inheritance", "include", "puts",] ); assert_eq!(result.definitions, ["A", "a"]); } #[test] fn test_parsing_python() { // test with a python class and list comprehension let parser = CodeParser::from_language(SupportedLanguages::Python); let code = r#" class A: def __init__(self): self.a = [x for x in range(10)] def hello_world(): print("Hello, world!") "#; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!(result.references, ["print", "range"]); assert_eq!(result.definitions, vec!["A", "hello_world"]); } #[test] fn test_parsing_on_c_sharp() { let parser = CodeParser::from_language(SupportedLanguages::CSharp); let code = r#" public class Greeter { public void SayHello() { System.Console.WriteLine("Hello, world!"); } } "#; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!(result.references, vec!["WriteLine"]); assert_eq!(result.definitions, vec!["Greeter", "SayHello"]); } #[test] fn test_parsing_on_typescript() { let parser = CodeParser::from_language(SupportedLanguages::Typescript); let code = r#" function Test() { console.log("Hello, TypeScript!"); otherThing(); } class MyClass { constructor() { let local = 5; this.myMethod(); } myMethod() { console.log("Hello, TypeScript!"); } } "#; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!(result.definitions, vec!["MyClass", "Test", "myMethod"]); assert_eq!(result.references, vec!["log", "otherThing"]); } #[test] fn test_parsing_on_javascript() { let parser = CodeParser::from_language(SupportedLanguages::Javascript); let code = r#" function Test() { console.log("Hello, JavaScript!"); otherThing(); } class MyClass { constructor() { let local = 5; this.myMethod(); } myMethod() { console.log("Hello, JavaScript!"); } } "#; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); 
assert_eq!(result.definitions, vec!["MyClass", "Test", "myMethod"]); assert_eq!(result.references, vec!["log", "otherThing"]); } #[test] fn test_parsing_on_java() { let parser = CodeParser::from_language(SupportedLanguages::Java); let code = r#" public class Hello { public static void main(String[] args) { System.out.printf("Hello %s!%n", args[0]); } } "#; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!(result.definitions, vec!["Hello", "main"]); assert_eq!(result.references, vec!["printf"]); } #[test] fn test_parsing_on_java_enum() { let parser = CodeParser::from_language(SupportedLanguages::Java); let code = r" enum Material { DENIM, CANVAS, SPANDEX_3_PERCENT } class Person { Person(string name) { this.name = name; this.pants = new Pants<Pocket>(); } String getName() { a = this.name; b = new one.two.Three(); c = Material.DENIM; } } "; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!(result.definitions, vec!["Material", "Person", "getName"]); assert!(result.references.is_empty()); } #[test] fn test_parsing_go() { let parser = CodeParser::from_language(SupportedLanguages::Go); // hello world go with struct let code = r" package main type Person struct { name string age int } func main() { p := Person{name: 'John', age: 30} fmt.Println(p) } "; let tree = parser.parse(code).unwrap(); let result = tree.references_and_definitions().unwrap(); assert_eq!(result.references, vec!["Println", "int", "string"]); assert_eq!(result.definitions, vec!["Person", "main"]); } } ================================================ FILE: swiftide-integrations/src/treesitter/compress_code_outline.rs ================================================ //! `CompressCodeOutline` is a transformer that reduces the size of the outline of a the parent file //! of a chunk to make it more relevant to the chunk. 
use std::sync::OnceLock; use anyhow::Result; use async_trait::async_trait; use swiftide_core::{Transformer, indexing::TextNode}; /// `CompressCodeChunk` rewrites the "Outline" metadata field of a chunk to /// condense it and make it more relevant to the chunk in question. It is useful as a /// step after chunking a file that has had outline generated for it with `FileToOutlineTreeSitter`. #[swiftide_macros::indexing_transformer( metadata_field_name = "Outline", default_prompt_file = "prompts/compress_code_outline.prompt.md" )] pub struct CompressCodeOutline {} fn extract_markdown_codeblock(text: String) -> String { static REGEX: OnceLock<regex::Regex> = OnceLock::new(); let re = REGEX.get_or_init(|| regex::Regex::new(r"(?sm)```\w*\n(.*?)```").unwrap()); let captures = re.captures(text.as_str()); captures .map(|c| c.get(1).unwrap().as_str().to_string()) .unwrap_or(text) } #[async_trait] impl Transformer for CompressCodeOutline { type Input = String; type Output = String; /// Asynchronously transforms an `TextNode` by reducing the size of the outline to make it more /// relevant to the chunk. /// /// This method uses the `SimplePrompt` client to compress the outline of the `TextNode` and /// updates the `TextNode` with the compressed outline. /// /// # Arguments /// /// * `node` - The `TextNode` to be transformed. /// /// # Returns /// /// A result containing the transformed `TextNode` or an error if the transformation fails. /// /// # Errors /// /// This function will return an error if the `SimplePrompt` client fails to generate a /// response. 
#[tracing::instrument(skip_all, name = "transformers.compress_code_outline")] async fn transform_node(&self, mut node: TextNode) -> Result<TextNode> { if node.metadata.get(NAME).is_none() { return Ok(node); } let prompt = self.prompt_template.clone().with_node(&node); let response = extract_markdown_codeblock(self.prompt(prompt).await?); node.metadata.insert(NAME, response); Ok(node) } fn concurrency(&self) -> Option<usize> { self.concurrency } } #[cfg(test)] mod test { use swiftide_core::MockSimplePrompt; use super::*; #[test_log::test(tokio::test)] async fn test_compress_code_template() { let template = default_prompt(); let outline = "Relevant Outline"; let code = "Code using outline"; let mut node = TextNode::new(code); node.metadata.insert("Outline", outline); let prompt = template.clone().with_node(&node); insta::assert_snapshot!(prompt.render().unwrap()); } #[tokio::test] async fn test_compress_code_outline() { let mut client = MockSimplePrompt::new(); client .expect_prompt() .returning(|_| Ok("RelevantOutline".to_string())); let transformer = CompressCodeOutline::builder() .client(client) .build() .unwrap(); let mut node = TextNode::new("Some text"); node.offset = 0; node.original_size = 100; node.metadata .insert("Outline".to_string(), "Some outline".to_string()); let result = transformer.transform_node(node).await.unwrap(); assert_eq!(result.chunk, "Some text"); assert_eq!(result.metadata.get("Outline").unwrap(), "RelevantOutline"); } } ================================================ FILE: swiftide-integrations/src/treesitter/metadata_qa_code.rs ================================================ //! Generate questions and answers based on code chunks and add them as metadata use anyhow::Result; use async_trait::async_trait; use swiftide_core::{Transformer, indexing::TextNode}; /// `MetadataQACode` is responsible for generating questions and answers based on code chunks. 
/// This struct integrates with the indexing pipeline to enhance the metadata of each code chunk
/// by adding relevant questions and answers.
#[swiftide_macros::indexing_transformer(
    metadata_field_name = "Questions and Answers (code)",
    default_prompt_file = "prompts/metadata_qa_code.prompt.md"
)]
pub struct MetadataQACode {
    /// Number of question/answer pairs the prompt asks the LLM to generate.
    #[builder(default = "5")]
    num_questions: usize,
}

#[async_trait]
impl Transformer for MetadataQACode {
    type Input = String;
    type Output = String;

    /// Asynchronously transforms a `TextNode` by generating questions and answers for its code
    /// chunk.
    ///
    /// This method uses the `SimplePrompt` client to generate questions and answers based on the
    /// code chunk and adds this information to the node's metadata.
    ///
    /// # Arguments
    ///
    /// * `node` - The `TextNode` to be transformed.
    ///
    /// # Returns
    ///
    /// A result containing the transformed `TextNode` or an error if the transformation fails.
    ///
    /// # Errors
    ///
    /// This function will return an error if the `SimplePrompt` client fails to generate a
    /// response.
#[tracing::instrument(skip_all, name = "transformers.metadata_qa_code")]
    async fn transform_node(&self, mut node: TextNode) -> Result<TextNode> {
        let mut prompt = self
            .prompt_template
            .clone()
            .with_node(&node)
            .with_context_value("questions", self.num_questions);

        // If an earlier transformer added an outline of the parent file, feed it to
        // the prompt as extra context for better questions.
        if let Some(outline) = node.metadata.get("Outline") {
            prompt = prompt.with_context_value("outline", outline.as_str());
        }

        let response = self.prompt(prompt).await?;

        node.metadata.insert(NAME, response);

        Ok(node)
    }

    fn concurrency(&self) -> Option<usize> {
        self.concurrency
    }
}

#[cfg(test)]
mod test {
    use swiftide_core::{MockSimplePrompt, assert_default_prompt_snapshot};

    use super::*;

    assert_default_prompt_snapshot!("test", "questions" => 5);

    // Same template, but with the optional outline context supplied.
    #[tokio::test]
    async fn test_template_with_outline() {
        let template = default_prompt();

        let prompt = template
            .clone()
            .with_node(&TextNode::new("test"))
            .with_context_value("questions", 5)
            .with_context_value("outline", "Test outline");

        insta::assert_snapshot!(prompt.render().unwrap());
    }

    // Verifies the LLM response is stored under the expected metadata key.
    #[tokio::test]
    async fn test_metadata_qacode() {
        let mut client = MockSimplePrompt::new();
        client
            .expect_prompt()
            .returning(|_| Ok("Q1: Hello\nA1: World".to_string()));

        let transformer = MetadataQACode::builder().client(client).build().unwrap();

        let node = TextNode::new("Some text");

        let result = transformer.transform_node(node).await.unwrap();

        assert_eq!(
            result.metadata.get("Questions and Answers (code)").unwrap(),
            "Q1: Hello\nA1: World"
        );
    }
}


================================================
FILE: swiftide-integrations/src/treesitter/metadata_refs_defs_code.rs
================================================
//! Adds references and definitions found in code as metadata to chunks
//!
//! Uses tree-sitter to do the extractions. It tries to only get unique definitions and references,
//! and only references that are not local.
//!
//! See the [`crate::treesitter::CodeParser`] tests for some examples.
//!
//! # Example
//!
//! ```no_run
//! # use swiftide_core::indexing::TextNode;
//! # use swiftide_integrations::treesitter::transformers::metadata_refs_defs_code::*;
//! # use swiftide_core::Transformer;
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let transformer = MetadataRefsDefsCode::try_from_language("rust").unwrap();
//! let code = r#"
//! fn main() {
//!     println!("Hello, World!");
//! }
//! "#;
//! let mut node = TextNode::new(code.to_string());
//!
//! node = transformer.transform_node(node).await.unwrap();
//!
//! assert_eq!(
//!     node.metadata.get(NAME_REFERENCES).unwrap().as_str().unwrap(),
//!     "println"
//! );
//! assert_eq!(
//!     node.metadata.get(NAME_DEFINITIONS).unwrap().as_str().unwrap(),
//!     "main"
//! );
//! # Ok(())
//! # }
//! ```
use std::sync::Arc;

use swiftide_core::{Transformer, indexing::TextNode};

use crate::treesitter::{CodeParser, SupportedLanguages};
use anyhow::{Context as _, Result};
use async_trait::async_trait;

/// Metadata key under which extracted references are stored.
pub const NAME_REFERENCES: &str = "References (code)";
/// Metadata key under which extracted definitions are stored.
pub const NAME_DEFINITIONS: &str = "Definitions (code)";

/// `MetadataRefsDefsCode` is responsible for extracting references and definitions.
#[swiftide_macros::indexing_transformer(derive(skip_default))]
pub struct MetadataRefsDefsCode {
    // Shared tree-sitter parser configured for a single language.
    code_parser: Arc<CodeParser>,
}

impl MetadataRefsDefsCode {
    /// Tries to build a new `MetadataRefsDefsCode` transformer
    ///
    /// # Errors
    ///
    /// Language is not supported by tree-sitter
    pub fn try_from_language(language: impl TryInto<SupportedLanguages>) -> Result<Self> {
        let language: SupportedLanguages = language
            .try_into()
            .ok()
            .context("Treesitter language not supported")?;

        MetadataRefsDefsCode::builder()
            .code_parser(CodeParser::from_language(language))
            .build()
    }
}

#[async_trait]
impl Transformer for MetadataRefsDefsCode {
    type Input = String;
    type Output = String;

    /// Extracts references and definitions from code and
    /// adds them as metadata to the node if present
    async fn transform_node(&self, mut node: TextNode) -> Result<TextNode> {
        let refs_defs = self
            .code_parser
            .parse(&node.chunk)?
            .references_and_definitions()?;

        // Only insert the metadata keys when something was found, so empty results
        // do not pollute the node metadata.
        if !refs_defs.references.is_empty() {
            node.metadata
                .insert(NAME_REFERENCES.to_string(), refs_defs.references.join(","));
        }

        if !refs_defs.definitions.is_empty() {
            node.metadata.insert(
                NAME_DEFINITIONS.to_string(),
                refs_defs.definitions.join(","),
            );
        }

        Ok(node)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use test_case::test_case;

    #[test_case("rust", "fn main() { println!(\"Hello, World!\"); }", "println", "main"; "rust")]
    #[test_case("ruby", "def main; puts 'Hello, World!'; end", "puts", "main"; "ruby")]
    #[test_case("python", "def main(): print('Hello, World!')", "print", "main"; "python")]
    #[test_case("javascript", "function main() { console.log('Hello, World!'); }", "log", "main"; "javascript")]
    #[test_case("typescript", "function main() { console.log('Hello, World!'); }", "log", "main"; "typescript")]
    #[test_case("java", "public class Main { public static void main(String[] args) { System.out.println(\"Hello, World!\"); } }", "println", "Main,main"; "java")]
    #[test_case("c-sharp", "public class Program { public static void Main(string[] args) { System.Console.WriteLine(\"Hello, World!\"); } }", "WriteLine", "Main,Program"; "c-sharp")]
    #[tokio::test]
    async fn assert_refs_defs_from_code(
        lang: &str,
        code: &str,
        expected_references: &str,
        expected_definitions: &str,
    ) {
        let transformer = MetadataRefsDefsCode::try_from_language(lang).unwrap();
        let node = TextNode::new(code);

        let node = transformer.transform_node(node).await.unwrap();

        let references = node
            .metadata
            .get(NAME_REFERENCES)
            .unwrap()
            .as_str()
            .unwrap()
            .to_string();

        let definitions = node
            .metadata
            .get(NAME_DEFINITIONS)
            .unwrap()
            .as_str()
            .unwrap()
            .to_string();

        assert_eq!(references, expected_references);
        assert_eq!(definitions, expected_definitions);
    }
}


================================================
FILE: swiftide-integrations/src/treesitter/mod.rs
================================================
//! Chunking code with tree-sitter and various tools
mod code_tree;
mod outliner;
mod queries;
mod splitter;
mod supported_languages;

pub use code_tree::{CodeParser, CodeTree, ReferencesAndDefinitions};
pub use outliner::{CodeOutliner, CodeOutlinerBuilder};
pub use splitter::{ChunkSize, CodeSplitter, CodeSplitterBuilder};
pub use supported_languages::SupportedLanguages;

pub mod chunk_code;
pub mod compress_code_outline;
pub mod metadata_qa_code;
pub mod metadata_refs_defs_code;
pub mod outline_code_tree_sitter;

pub mod transformers {
    pub use super::chunk_code::{self, ChunkCode};
    pub use super::compress_code_outline::{self, CompressCodeOutline};
    pub use super::metadata_qa_code::{self, MetadataQACode};
    pub use super::metadata_refs_defs_code::{self, MetadataRefsDefsCode};
    pub use super::outline_code_tree_sitter::{self, OutlineCodeTreeSitter};
}


================================================
FILE: swiftide-integrations/src/treesitter/outline_code_tree_sitter.rs
================================================
//! Add the outline of the code in the given file to the metadata of a node, using tree-sitter.
use anyhow::Result;
use async_trait::async_trait;
use swiftide_core::Transformer;
use swiftide_core::indexing::TextNode;

use crate::treesitter::{CodeOutliner, SupportedLanguages};

/// `OutlineCodeTreeSitter` adds an "Outline" field to the metadata of a node that contains
/// a summary of the code in the node. It uses the tree-sitter parser to parse the code and
/// remove any information that is less relevant for tasks that consider the file as a whole.
#[swiftide_macros::indexing_transformer(metadata_field_name = "Outline", derive(skip_default))]
pub struct OutlineCodeTreeSitter {
    outliner: CodeOutliner,
    /// Chunks smaller than this many bytes are passed through untouched; `None` outlines all.
    minimum_file_size: Option<usize>,
}

impl OutlineCodeTreeSitter {
    /// Tries to create a `OutlineCodeTreeSitter` instance for a given programming language.
    ///
    /// # Parameters
    /// - `lang`: The programming language to be used to parse the code. It should implement
    ///   `TryInto<SupportedLanguages>`.
    /// - `minimum_file_size`: Skip outlining for chunks smaller than this size in bytes.
    ///
    /// # Returns
    /// - `Result<Self>`: Returns an instance of `OutlineCodeTreeSitter` if successful, otherwise
    ///   returns an error.
    ///
    /// # Errors
    /// - Returns an error if the language is not supported or if the `CodeOutliner` fails to build.
    pub fn try_for_language(
        lang: impl TryInto<SupportedLanguages>,
        minimum_file_size: Option<usize>,
    ) -> Result<Self> {
        Ok(Self {
            outliner: CodeOutliner::builder().try_language(lang)?.build()?,
            minimum_file_size,
            // NOTE(review): these fields are not declared above, so they are presumably
            // injected by the `indexing_transformer` macro; this transformer needs no LLM client.
            client: None,
            concurrency: None,
            indexing_defaults: None,
        })
    }
}

#[async_trait]
impl Transformer for OutlineCodeTreeSitter {
    type Input = String;
    type Output = String;

    /// Adds context to the metadata of a `TextNode` containing code in the "Outline" field.
    ///
    /// It uses the `CodeOutliner` to generate the context.
    ///
    /// # Parameters
    /// - `node`: The `TextNode` containing the code of which the context is to be generated.
    ///
    /// # Returns
    /// - `TextNode`: The same `TextNode` instances, with the metadata updated to include the
    ///   generated context.
    ///
    /// # Errors
    /// - If the code outlining fails, an error is sent downstream.
    #[tracing::instrument(skip_all, name = "transformers.outline_code_tree_sitter")]
    async fn transform_node(&self, mut node: TextNode) -> Result<TextNode> {
        // Small chunks are left untouched: an outline would add little value there.
        if let Some(minimum_file_size) = self.minimum_file_size
            && node.chunk.len() < minimum_file_size
        {
            return Ok(node);
        }

        let outline_result = self.outliner.outline(&node.chunk)?;
        node.metadata.insert(NAME, outline_result);

        Ok(node)
    }
}


================================================
FILE: swiftide-integrations/src/treesitter/outliner.rs
================================================
use anyhow::{Context as _, Result};
use tree_sitter::{Node, Parser, TreeCursor};

use derive_builder::Builder;

use super::supported_languages::SupportedLanguages;

#[derive(Debug, Builder, Clone)]
/// Generates a summary of a code file.
///
/// It does so by parsing the code file and removing function bodies, leaving only the function
/// signatures and other top-level declarations along with any comments.
///
/// The resulting summary can be used as a context when considering subsets of the code file, or for
/// determining relevance of the code file to a given task.
#[builder(setter(into), build_fn(error = "anyhow::Error"))]
pub struct CodeOutliner {
    #[builder(setter(custom))]
    language: SupportedLanguages,
}

impl CodeOutlinerBuilder {
    /// Attempts to set the language for the `CodeOutliner`.
    ///
    /// # Arguments
    ///
    /// * `language` - A value that can be converted into `SupportedLanguages`.
    ///
    /// # Returns
    ///
    /// * `Result<Self>` - The builder instance with the language set, or an error if the language
    ///   is not supported.
    ///
    /// # Errors
    /// * If the language is not supported, an error is returned.
pub fn try_language(mut self, language: impl TryInto<SupportedLanguages>) -> Result<Self> {
        // Resolve the language first so a clear error surfaces before the builder
        // state is touched.
        let supported = language
            .try_into()
            .ok()
            .context("Treesitter language not supported")?;
        self.language = Some(supported);
        Ok(self)
    }
}

impl CodeOutliner {
    /// Builds a `CodeOutliner` for the given language.
    ///
    /// # Arguments
    ///
    /// * `language` - The programming language for which the code will be outlined.
    ///
    /// # Returns
    ///
    /// * `Self` - A freshly constructed `CodeOutliner`.
    pub fn new(language: SupportedLanguages) -> Self {
        CodeOutliner { language }
    }

    /// Returns a builder for configuring a `CodeOutliner` step by step.
    ///
    /// # Returns
    ///
    /// * `CodeOutlinerBuilder` - A default builder instance for `CodeOutliner`.
    pub fn builder() -> CodeOutlinerBuilder {
        CodeOutlinerBuilder::default()
    }

    /// Outlines a code file.
    ///
    /// # Arguments
    ///
    /// * `code` - The source code to be split.
    ///
    /// # Returns
    ///
    /// * `Result<String>` - A result containing a string, or an error if the code could not be
    ///   parsed.
    ///
    /// # Errors
    /// * If the code could not be parsed, an error is returned.
pub fn outline(&self, code: &str) -> Result<String> {
        let mut parser = Parser::new();
        parser.set_language(&self.language.into())?;
        let tree = parser.parse(code, None).context("No nodes found")?;
        let root_node = tree.root_node();
        if root_node.has_error() {
            anyhow::bail!("Root node has invalid syntax");
        }
        let mut cursor = root_node.walk();
        // The outline is always a subset of the source, so this never reallocates.
        let mut summary = String::with_capacity(code.len());
        let mut last_end = 0;
        self.outline_node(&mut cursor, code, &mut summary, &mut last_end);
        Ok(summary)
    }

    /// Returns true when `node` is a body block that should be omitted from the outline
    /// (e.g. a function body); the node kinds checked depend on the language grammar.
    fn is_unneeded_node(&self, node: Node) -> bool {
        match self.language {
            SupportedLanguages::Rust | SupportedLanguages::Java | SupportedLanguages::CSharp => {
                matches!(node.kind(), "block")
            }
            SupportedLanguages::Typescript | SupportedLanguages::Javascript => {
                matches!(node.kind(), "statement_block")
            }
            SupportedLanguages::Python => match node.kind() {
                "block" => {
                    // Python's grammar uses "block" widely; only drop function bodies.
                    let parent = node.parent().expect("Python block node has no parent");
                    parent.kind() == "function_definition"
                }
                _ => false,
            },
            SupportedLanguages::Ruby => match node.kind() {
                "body_statement" => {
                    // Only method bodies are stripped; class/module bodies are kept.
                    let parent = node
                        .parent()
                        .expect("Ruby body_statement node has no parent");
                    parent.kind() == "method"
                }
                _ => false,
            },
            // Outlining is not implemented for the remaining grammars yet.
            SupportedLanguages::Go => unimplemented!(),
            SupportedLanguages::Solidity => unimplemented!(),
            SupportedLanguages::C => unimplemented!(),
            SupportedLanguages::Cpp => unimplemented!(),
            SupportedLanguages::Elixir => unimplemented!(),
            SupportedLanguages::HTML => unimplemented!(),
            SupportedLanguages::PHP => unimplemented!(),
        }
    }

    /// Outlines a syntax node, appending its kept text to `summary`.
    ///
    /// # Arguments
    ///
    /// * `node` - The syntax node to be chunked.
    /// * `source` - The source code as a string.
    /// * `last_end` - The end byte of the last chunk.
    ///
    /// # Returns
    ///
    /// * `String` - A summary of the syntax node.
fn outline_node(
        &self,
        cursor: &mut TreeCursor,
        source: &str,
        summary: &mut String,
        last_end: &mut usize,
    ) {
        let node = cursor.node();

        // If the node is not needed in the summary, skip it and go to the next sibling
        if self.is_unneeded_node(node) {
            // Keep whatever lies between the previous emitted byte and this node
            // (whitespace, comments), then advance past the skipped body.
            summary.push_str(&source[*last_end..node.start_byte()]);
            *last_end = node.end_byte();
            if cursor.goto_next_sibling() {
                self.outline_node(cursor, source, summary, last_end);
            }
            return;
        }

        // Clone so the recursion into children does not disturb this cursor's
        // position for the sibling walk below.
        let mut next_cursor = cursor.clone();
        // If the node is a non-leaf, recursively outline its children
        if next_cursor.goto_first_child() {
            self.outline_node(&mut next_cursor, source, summary, last_end);
            // If the node is a leaf, add the text to the summary
        } else {
            summary.push_str(&source[*last_end..node.end_byte()]);
            *last_end = node.end_byte();
        }

        if cursor.goto_next_sibling() {
            self.outline_node(cursor, source, summary, last_end);
        } else {
            // Done with this node
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Test every supported language.
    // We should strip away all code blocks and leave only imports, comments, function signatures,
    // class, interface and structure definitions and definitions of constants, variables and other
    // members.
#[test] fn test_outline_rust() { let code = r#" use anyhow::{Context as _, Result}; // This is a comment fn main(a: usize, b: usize) -> usize { println!("Hello, world!"); } pub struct Bla { a: usize } impl Bla { fn ok(&mut self) { self.a = 1; } }"#; let outliner = CodeOutliner::new(SupportedLanguages::Rust); let summary = outliner.outline(code).unwrap(); assert_eq!( summary, "\nuse anyhow::{Context as _, Result};\n// This is a comment\nfn main(a: usize, b: usize) -> usize \n\npub struct Bla {\n a: usize\n}\n\nimpl Bla {\n fn ok(&mut self) \n}" ); } #[test] fn test_outline_typescript() { let code = r#" import { Context as _, Result } from 'anyhow'; // This is a comment function main(a: number, b: number): number { console.log("Hello, world!"); } export class Bla { a: number; } export interface Bla { ok(): void; }"#; let outliner = CodeOutliner::new(SupportedLanguages::Typescript); let summary = outliner.outline(code).unwrap(); assert_eq!( summary, "\nimport { Context as _, Result } from 'anyhow';\n// This is a comment\nfunction main(a: number, b: number): number \n\nexport class Bla {\n a: number;\n}\n\nexport interface Bla {\n ok(): void;\n}" ); } #[test] fn test_outline_python() { let code = r#" import sys # This is a comment def main(a: int, b: int) -> int: print("Hello, world!") class Bla: def __init__(self): self.a = 1 def ok(self): self.a = 1 "#; let outliner = CodeOutliner::new(SupportedLanguages::Python); let summary = outliner.outline(code).unwrap(); assert_eq!( summary, "\nimport sys\n# This is a comment\ndef main(a: int, b: int) -> int:\n \n\nclass Bla:\n def __init__(self):\n \n\n def ok(self):\n " ); } #[test] fn test_outline_ruby() { let code = r#" require 'anyhow' # This is a comment def main(a, b) puts "Hello, world!" 
end class Bla def ok @a = 1 end end "#; let outliner = CodeOutliner::new(SupportedLanguages::Ruby); let summary = outliner.outline(code).unwrap(); assert_eq!( summary, "\nrequire 'anyhow'\n# This is a comment\ndef main(a, b)\n \nend\n\nclass Bla\n def ok\n \n end\nend" ); } #[test] fn test_outline_javascript() { let code = r#" import { Context as _, Result } from 'anyhow'; // This is a comment function main(a, b) { console.log("Hello, world!"); } class Bla { constructor() { this.a = 1; } ok() { this.a = 1; } } "#; let outliner = CodeOutliner::new(SupportedLanguages::Javascript); let summary = outliner.outline(code).unwrap(); assert_eq!( summary, "\nimport { Context as _, Result } from 'anyhow';\n// This is a comment\nfunction main(a, b) \n\nclass Bla {\n constructor() \n\n ok() \n}" ); } #[test] fn test_outline_java() { let code = r#" import java.io.PrintStream; import java.util.Scanner; public class HelloWorld { // This is a comment public static void main(String[] args) { PrintStream out = System.out; out.println("Hello, World!"); } } "#; let outliner = CodeOutliner::new(SupportedLanguages::Java); let summary = outliner.outline(code).unwrap(); println!("{summary}"); assert_eq!( summary, "\nimport java.io.PrintStream;\nimport java.util.Scanner;\n\npublic class HelloWorld {\n // This is a comment\n public static void main(String[] args) \n}" ); } } ================================================ FILE: swiftide-integrations/src/treesitter/prompts/compress_code_outline.prompt.md ================================================ # Filtering Code Outline Your task is to filter the given file outline to the code chunk provided. The goal is to provide a context that is still contains the lines needed for understanding the code in the chunk whilst leaving out any irrelevant information. 
## Constraints - Only use lines from the provided context, do not add any additional information - Ensure that the selection you make is the most appropriate for the code chunk - Make sure you include any definitions or imports that are used in the code chunk - You do not need to repeat the code chunk in your response, it will be appended directly after your response. - Do not use lines that are present in the code chunk ## Code ``` {{ node.chunk }} ``` ## Outline ``` {{ node.metadata["Outline"] }} ``` ================================================ FILE: swiftide-integrations/src/treesitter/prompts/metadata_qa_code.prompt.md ================================================ # Task Your task is to generate questions and answers for the given code. Given that somebody else might ask questions about the code, consider things like: - What does this code do? - What other internal parts does the code use? - Does this code have any dependencies? - What are some potential use cases for this code? - ... and so on # Constraints - Generate only {{questions}} questions and answers. - Only respond in the example format - Only respond with questions and answers that can be derived from the code. # Example Respond in the following example format and do not include anything else: ``` Q1: What does this code do? A1: It transforms strings into integers. Q2: What other internal parts does the code use? A2: A hasher to hash the strings. 
``` {% if outline %} ## Outline of the parent file ``` {{ outline }} ``` {% endif %} # Code ``` {{ node.chunk }} ``` ================================================ FILE: swiftide-integrations/src/treesitter/queries.rs ================================================ // https://github.com/tree-sitter/tree-sitter-ruby/blob/master/queries/tags.scm pub mod ruby { pub const DEFS: &str = r" ( [ (method name: (_) @name) (singleton_method name: (_) @name) ] ) (alias name: (_) @name) (setter (identifier) @ignore) ( [ (class name: [ (constant) @name (scope_resolution name: (_) @name) ]) (singleton_class value: [ (constant) @name (scope_resolution name: (_) @name) ]) ] ) ( (module name: [ (constant) @name (scope_resolution name: (_) @name) ]) ) "; pub const REFS: &str = r#" (call method: (identifier) @name) ( [(identifier) (constant)] @name (#is-not? local) (#not-match? @name "^(lambda|load|require|require_relative|__FILE__|__LINE__)$") ) "#; } // https://github.com/tree-sitter/tree-sitter-python/blob/master/queries/tags.scm pub mod python { pub const DEFS: &str = r#" (class_definition name: (identifier) @name) ( (function_definition name: (identifier) @name) (#not-eq? @name "__init__") ) "#; pub const REFS: &str = " (call function: [ (identifier) @name (attribute attribute: (identifier)) ]) "; } // https://github.com/tree-sitter/tree-sitter-typescript/blob/master/queries/tags.scm pub mod typescript { pub const DEFS: &str = r#" (function_signature name: (identifier) @name) (method_signature name: (property_identifier) @name) (abstract_method_signature name: (property_identifier) @name) (abstract_class_declaration name: (type_identifier) @name) (module name: (identifier) @name) (interface_declaration name: (type_identifier) @name) ( (method_definition name: (property_identifier) @name) (#not-eq? 
@name "constructor") ) ( [ (class name: (_) @name) (class_declaration name: (_) @name) ] ) ( [ (function_expression name: (identifier) @name) (function_declaration name: (identifier) @name) (generator_function name: (identifier) @name) (generator_function_declaration name: (identifier) @name) ] ) ( (lexical_declaration (variable_declarator name: (identifier) @name value: [(arrow_function) (function_expression)])) ) ( (variable_declaration (variable_declarator name: (identifier) @name value: [(arrow_function) (function_expression)])) ) "#; pub const REFS: &str = r#" (type_annotation (type_identifier) @name) (new_expression constructor: (identifier) @name) ( (call_expression function: (identifier) @name) (#not-match? @name "^(require)$") ) (call_expression function: (member_expression property: (property_identifier) @name) arguments: (_)) "#; } // https://github.com/tree-sitter/tree-sitter-javascript/blob/master/queries/tags.scm pub mod javascript { pub const DEFS: &str = r#" ( (method_definition name: (property_identifier) @name) (#not-eq? @name "constructor") ) ( [ (class name: (_) @name) (class_declaration name: (_) @name) ] ) ( [ (function_expression name: (identifier) @name) (function_declaration name: (identifier) @name) (generator_function name: (identifier) @name) (generator_function_declaration name: (identifier) @name) ] ) ( (lexical_declaration (variable_declarator name: (identifier) @name value: [(arrow_function) (function_expression)]) @definition.function) ) ( (variable_declaration (variable_declarator name: (identifier) @name value: [(arrow_function) (function_expression)]) @definition.function) ) (assignment_expression left: [ (identifier) @name (member_expression property: (property_identifier) @name) ] right: [(arrow_function) (function_expression)] ) (pair key: (property_identifier) @name value: [(arrow_function) (function_expression)]) "#; pub const REFS: &str = r#" ( (call_expression function: (identifier) @name) (#not-match? 
@name "^(require)$") ) (call_expression function: (member_expression property: (property_identifier) @name) arguments: (_)) (new_expression constructor: (_) @name) (export_statement value: (assignment_expression left: (identifier) @name right: ([ (number) (string) (identifier) (undefined) (null) (new_expression) (binary_expression) (call_expression) ]))) "#; } // https://github.com/tree-sitter/tree-sitter-rust/blob/master/queries/tags.scm pub mod rust { pub const DEFS: &str = " (struct_item name: (type_identifier) @name) (enum_item name: (type_identifier) @name) (union_item name: (type_identifier) @name) (type_item name: (type_identifier) @name) (declaration_list (function_item name: (identifier) @name)) (function_item name: (identifier) @name) (trait_item name: (type_identifier) @name) (mod_item name: (identifier) @name) (macro_definition name: (identifier) @name) "; pub const REFS: &str = " (call_expression function: (identifier) @name) (call_expression function: (field_expression field: (field_identifier) @name)) (macro_invocation macro: (identifier) @name) "; } // https://github.com/tree-sitter/tree-sitter-java/blob/master/queries/tags.scm pub mod java { pub const DEFS: &str = " (class_declaration name: (identifier) @name) (enum_declaration name: (identifier) @name) (method_declaration name: (identifier) @name) (interface_declaration name: (identifier) @name) (type_list (type_identifier) @name) (superclass (type_identifier) @name)"; pub const REFS: &str = " (method_invocation name: (identifier) @name arguments: (argument_list)) (object_creation_expression type: (type_identifier) @name)"; } pub mod go { pub const DEFS: &str = r" (function_declaration name: (identifier) @name) (method_declaration name: (field_identifier) @name) (type_declaration (type_spec name: (type_identifier) @name type: (interface_type))) (type_declaration (type_spec name: (type_identifier) @name type: (struct_type))) (import_declaration (import_spec) @name) (var_declaration (var_spec name: 
(identifier) @name)) (const_declaration (const_spec name: (identifier) @name)) "; pub const REFS: &str = r#" (call_expression function: [ (identifier) @name (parenthesized_expression (identifier) @name) (selector_expression field: (field_identifier) @name) (parenthesized_expression (selector_expression field: (field_identifier) @name)) ]) (type_spec name: (type_identifier) @name) (package_clause "package" (package_identifier) @name) (type_identifier) @name "#; } pub mod solidity { pub const DEFS: &str = r" (function_definition name: (identifier) @name) (source_file (function_definition name: (identifier) @name)) (contract_declaration name: (identifier) @name) (interface_declaration name: (identifier) @name) (library_declaration name: (identifier) @name) (struct_declaration name: (identifier) @name) (enum_declaration name: (identifier) @name) (event_definition name: (identifier) @name) "; pub const REFS: &str = r" (call_expression (expression (identifier)) @name ) (call_expression (expression (member_expression property: (_) @name ))) (emit_statement name: (_) @name) (inheritance_specifier ancestor: (user_defined_type (_) @name . 
)) (import_directive import_name: (_) @name ) "; } // https://github.com/tree-sitter/tree-sitter-c-sharp/blob/master/queries/tags.scm pub mod csharp { pub const DEFS: &str = r" (class_declaration name: (identifier) @name) (interface_declaration name: (identifier) @name) (method_declaration name: (identifier) @name) (namespace_declaration name: (identifier) @name) "; pub const REFS: &str = r" (class_declaration (base_list (_) @name)) (interface_declaration (base_list (_) @name)) (object_creation_expression type: (identifier) @name) (type_parameter_constraints_clause (identifier) @name) (type_parameter_constraint (type type: (identifier) @name)) (variable_declaration type: (identifier) @name) (invocation_expression function: (member_access_expression name: (identifier) @name)) "; } ================================================ FILE: swiftide-integrations/src/treesitter/snapshots/swiftide_integrations__treesitter__compress_code_outline__test__compress_code_template.snap ================================================ --- source: swiftide-integrations/src/treesitter/compress_code_outline.rs expression: prompt.render().await.unwrap() --- # Filtering Code Outline Your task is to filter the given file outline to the code chunk provided. The goal is to provide a context that is still contains the lines needed for understanding the code in the chunk whilst leaving out any irrelevant information. ## Constraints - Only use lines from the provided context, do not add any additional information - Ensure that the selection you make is the most appropriate for the code chunk - Make sure you include any definitions or imports that are used in the code chunk - You do not need to repeat the code chunk in your response, it will be appended directly after your response. 
- Do not use lines that are present in the code chunk ## Code ``` Code using outline ``` ## Outline ``` Relevant Outline ``` ================================================ FILE: swiftide-integrations/src/treesitter/snapshots/swiftide_integrations__treesitter__metadata_qa_code__test__default_prompt.snap ================================================ --- source: swiftide-integrations/src/treesitter/metadata_qa_code.rs expression: prompt.render().await.unwrap() --- # Task Your task is to generate questions and answers for the given code. Given that somebody else might ask questions about the code, consider things like: - What does this code do? - What other internal parts does the code use? - Does this code have any dependencies? - What are some potential use cases for this code? - ... and so on # Constraints - Generate only 5 questions and answers. - Only respond in the example format - Only respond with questions and answers that can be derived from the code. # Example Respond in the following example format and do not include anything else: ``` Q1: What does this code do? A1: It transforms strings into integers. Q2: What other internal parts does the code use? A2: A hasher to hash the strings. ``` # Code ``` test ``` ================================================ FILE: swiftide-integrations/src/treesitter/snapshots/swiftide_integrations__treesitter__metadata_qa_code__test__template_with_outline.snap ================================================ --- source: swiftide-integrations/src/treesitter/metadata_qa_code.rs expression: prompt.render().await.unwrap() --- # Task Your task is to generate questions and answers for the given code. Given that somebody else might ask questions about the code, consider things like: - What does this code do? - What other internal parts does the code use? - Does this code have any dependencies? - What are some potential use cases for this code? - ... and so on # Constraints - Generate only 5 questions and answers. 
- Only respond in the example format - Only respond with questions and answers that can be derived from the code. # Example Respond in the following example format and do not include anything else: ``` Q1: What does this code do? A1: It transforms strings into integers. Q2: What other internal parts does the code use? A2: A hasher to hash the strings. ``` ## Outline of the parent file ``` Test outline ``` # Code ``` test ``` ================================================ FILE: swiftide-integrations/src/treesitter/splitter.rs ================================================ use anyhow::{Context as _, Result}; use std::ops::Range; use tree_sitter::{Node, Parser}; use derive_builder::Builder; use super::supported_languages::SupportedLanguages; // TODO: Instead of counting bytes, count tokens with titktoken const DEFAULT_MAX_BYTES: usize = 1500; #[derive(Debug, Builder, Clone)] /// Splits code files into meaningful chunks /// /// Supports splitting code files into chunks based on a maximum size or a range of bytes. #[builder(setter(into), build_fn(error = "anyhow::Error"))] pub struct CodeSplitter { /// Maximum size of a chunk in bytes or a range of bytes #[builder(default, setter(into))] chunk_size: ChunkSize, #[builder(setter(custom))] language: SupportedLanguages, } impl CodeSplitterBuilder { /// Attempts to set the language for the `CodeSplitter`. /// /// # Arguments /// /// * `language` - A value that can be converted into `SupportedLanguages`. /// /// # Returns /// /// * `Result<Self>` - The builder instance with the language set, or an error if the language /// is not supported. /// /// # Errors /// /// Errors if language is not supported pub fn try_language(mut self, language: impl TryInto<SupportedLanguages>) -> Result<Self> { self.language = Some( language .try_into() .ok() .context("Treesitter language not supported")?, ); Ok(self) } } #[derive(Debug, Clone)] /// Represents the size of a chunk, either as a fixed number of bytes or a range of bytes. 
pub enum ChunkSize {
    // Fixed upper bound only: chunks may be at most this many bytes.
    Bytes(usize),
    // `start` is the minimum chunk size (smaller chunks are discarded),
    // `end` is the maximum chunk size.
    Range(Range<usize>),
}

impl From<usize> for ChunkSize {
    /// Converts a `usize` into a `ChunkSize::Bytes` variant.
    fn from(size: usize) -> Self {
        ChunkSize::Bytes(size)
    }
}

impl From<Range<usize>> for ChunkSize {
    /// Converts a `Range<usize>` into a `ChunkSize::Range` variant.
    fn from(range: Range<usize>) -> Self {
        ChunkSize::Range(range)
    }
}

impl Default for ChunkSize {
    /// Provides a default value for `ChunkSize`, which is `ChunkSize::Bytes(DEFAULT_MAX_BYTES)`.
    fn default() -> Self {
        ChunkSize::Bytes(DEFAULT_MAX_BYTES)
    }
}

impl CodeSplitter {
    /// Creates a new `CodeSplitter` with the specified language and default chunk size.
    ///
    /// # Arguments
    ///
    /// * `language` - The programming language for which the code will be split.
    ///
    /// # Returns
    ///
    /// * `Self` - A new instance of `CodeSplitter`.
    pub fn new(language: SupportedLanguages) -> Self {
        Self {
            chunk_size: ChunkSize::default(),
            language,
        }
    }

    /// Creates a new builder for `CodeSplitter`.
    ///
    /// # Returns
    ///
    /// * `CodeSplitterBuilder` - A new builder instance for `CodeSplitter`.
    pub fn builder() -> CodeSplitterBuilder {
        CodeSplitterBuilder::default()
    }

    /// Recursively chunks a syntax node into smaller pieces based on the chunk size.
    ///
    /// Walks the direct children of `node`, greedily accumulating source bytes into
    /// `current_chunk` until adding the next child would reach `max_bytes()`;
    /// children that are individually larger than `max_bytes()` are recursed into
    /// one level deeper.
    ///
    /// # Arguments
    ///
    /// * `node` - The syntax node to be chunked.
    /// * `source` - The source code as a string.
    /// * `last_end` - The end byte of the last chunk.
    /// * `current_chunk` - Partially filled chunk carried over from the caller
    ///   (used when recursing), or `None` to start empty.
    ///
    /// # Returns
    ///
    /// * `Vec<String>` - A vector of code chunks as strings.
    fn chunk_node(
        &self,
        node: Node,
        source: &str,
        mut last_end: usize,
        current_chunk: Option<String>,
    ) -> Vec<String> {
        let mut new_chunks: Vec<String> = Vec::new();
        let mut current_chunk = current_chunk.unwrap_or_default();
        for child in node.children(&mut node.walk()) {
            debug_assert!(
                current_chunk.len() <= self.max_bytes(),
                "Chunk too big: {} > {}",
                current_chunk.len(),
                self.max_bytes()
            );
            // if the next child will make the chunk too big then there are two options:
            // 1. if the next child is too big to fit in a whole chunk, then recursively chunk it
            //    one level down
            // 2. if the next child is small enough to fit in a chunk, then add the current chunk to
            //    the list and start a new chunk
            //
            // NOTE(review): sizes are measured from `last_end`, not `child.start_byte()`,
            // so any bytes between siblings (whitespace, comments) are attributed to the
            // next child.
            let next_child_size = child.end_byte() - last_end;
            if current_chunk.len() + next_child_size >= self.max_bytes() {
                if next_child_size > self.max_bytes() {
                    // The child alone exceeds the max: recurse, seeding the recursion with
                    // the current partial chunk; the (possibly partial) tail chunk comes
                    // back as the new `current_chunk`.
                    let mut sub_chunks = self.chunk_node(child, source, last_end, Some(current_chunk));
                    current_chunk = sub_chunks.pop().unwrap_or_default();
                    new_chunks.extend(sub_chunks);
                } else {
                    // NOTE: if the current chunk was smaller than then the min_bytes, then it is
                    // discarded here
                    if !current_chunk.is_empty() && current_chunk.len() > self.min_bytes() {
                        new_chunks.push(current_chunk);
                    }
                    current_chunk = source[last_end..child.end_byte()].to_string();
                }
            } else {
                current_chunk += &source[last_end..child.end_byte()];
            }
            last_end = child.end_byte();
        }
        // Flush the trailing chunk; as above, chunks at or below `min_bytes()` are dropped.
        if !current_chunk.is_empty() && current_chunk.len() > self.min_bytes() {
            new_chunks.push(current_chunk);
        }
        new_chunks
    }

    /// Splits the given code into chunks based on the chunk size.
    ///
    /// # Arguments
    ///
    /// * `code` - The source code to be split.
    ///
    /// # Returns
    ///
    /// * `Result<Vec<String>>` - A result containing a vector of code chunks as strings, or an
    ///   error if the code could not be parsed.
    ///
    /// # Errors
    ///
    /// Returns an error if the node cannot be found or fails to parse
    pub fn split(&self, code: &str) -> Result<Vec<String>> {
        let mut parser = Parser::new();
        parser.set_language(&self.language.into())?;
        let tree = parser.parse(code, None).context("No nodes found")?;
        let root_node = tree.root_node();
        // On syntax errors, fall back to returning the whole input as a single chunk
        // instead of splitting on an unreliable parse tree.
        if root_node.has_error() {
            tracing::warn!("Syntax error parsing code: {code:?}");
            return Ok(vec![code.to_string()]);
        }
        Ok(self.chunk_node(root_node, code, 0, None))
    }

    /// Returns the maximum number of bytes allowed in a chunk.
    ///
    /// # Returns
    ///
    /// * `usize` - The maximum number of bytes in a chunk.
    fn max_bytes(&self) -> usize {
        match &self.chunk_size {
            ChunkSize::Bytes(size) => *size,
            ChunkSize::Range(range) => range.end,
        }
    }

    /// Returns the minimum number of bytes allowed in a chunk.
    ///
    /// # Returns
    ///
    /// * `usize` - The minimum number of bytes in a chunk.
    fn min_bytes(&self) -> usize {
        if let ChunkSize::Range(range) = &self.chunk_size {
            range.start
        } else {
            0
        }
    }
}
{r#" fn main() { println!("Hello, World!"); println!("Goodbye, World!"); } "#}; let chunks = splitter.split(text).unwrap(); assert!(chunks.iter().all(|chunk| chunk.len() <= 50)); assert!( chunks .windows(2) .all(|pair| pair.iter().map(String::len).sum::<usize>() >= 50) ); assert_eq!( chunks, vec![ "fn main() {\n println!(\"Hello, World!\");", "\n println!(\"Goodbye, World!\");\n}", ] ); } #[test] fn test_empty_text() { let splitter = CodeSplitter::builder() .try_language(SupportedLanguages::Rust) .unwrap() .chunk_size(50) .build() .unwrap(); let text = ""; let chunks = splitter.split(text).unwrap(); dbg!(&chunks); assert_eq!(chunks.len(), 0); } #[test] fn test_range_max() { let splitter = CodeSplitter::builder() .try_language(SupportedLanguages::Rust) .unwrap() .chunk_size(0..50) .build() .unwrap(); let text = indoc! {r#" fn main() { println!("Hello, World!"); println!("Goodbye, World!"); } "#}; let chunks = splitter.split(text).unwrap(); assert_eq!( chunks, vec![ "fn main() {\n println!(\"Hello, World!\");", "\n println!(\"Goodbye, World!\");\n}", ] ); } #[test] fn test_range_min_and_max() { let splitter = CodeSplitter::builder() .try_language(SupportedLanguages::Rust) .unwrap() .chunk_size(20..50) .build() .unwrap(); let text = indoc! 
{r#" fn main() { println!("Hello, World!"); println!("Goodbye, World!"); } "#}; let chunks = splitter.split(text).unwrap(); assert!(chunks.iter().all(|chunk| chunk.len() <= 50)); assert!( chunks .windows(2) .all(|pair| pair.iter().map(String::len).sum::<usize>() > 50) ); assert!(chunks.iter().all(|chunk| chunk.len() >= 20)); assert_eq!( chunks, vec![ "fn main() {\n println!(\"Hello, World!\");", "\n println!(\"Goodbye, World!\");\n}" ] ); } #[test] fn test_on_self() { // read the current file let code = include_str!("splitter.rs"); // try chunking with varying ranges of bytes, give me ten with different min and max let ranges = vec![ 10..200, 50..100, 100..150, 150..200, 200..250, 250..300, 300..350, 350..400, 400..450, 450..500, ]; for range in ranges { let min = range.start; let max = range.end; let splitter = CodeSplitter::builder() .try_language("rust") .unwrap() .chunk_size(range) .build() .unwrap(); assert_eq!(splitter.min_bytes(), min); assert_eq!(splitter.max_bytes(), max); let chunks = splitter.split(code).unwrap(); assert!(chunks.iter().all(|chunk| chunk.len() <= max)); let chunk_pairs_that_are_smaller_than_max = chunks .windows(2) .filter(|pair| pair.iter().map(String::len).sum::<usize>() < max); assert!( chunk_pairs_that_are_smaller_than_max.clone().count() == 0, "max: {}, {} + {}, {:?}", max, chunk_pairs_that_are_smaller_than_max .clone() .next() .unwrap()[0] .len(), chunk_pairs_that_are_smaller_than_max .clone() .next() .unwrap()[1] .len(), chunk_pairs_that_are_smaller_than_max .collect::<Vec<_>>() .first() ); assert!(chunks.iter().all(|chunk| chunk.len() >= min)); assert!( chunks.iter().all(|chunk| chunk.len() >= min), "{:?}", chunks .iter() .filter(|chunk| chunk.len() < min) .collect::<Vec<_>>() ); assert!( chunks.iter().all(|chunk| chunk.len() <= max), "max = {}, chunks = {:?}", max, chunks .iter() .filter(|chunk| chunk.len() > max) .collect::<Vec<_>>() ); } // assert there are no nodes smaller than 10 } } 
================================================ FILE: swiftide-integrations/src/treesitter/supported_languages.rs ================================================ //! This module defines the supported programming languages for the Swiftide project and provides //! utility functions for mapping these languages to their respective file extensions and //! tree-sitter language objects. //! //! The primary purpose of this module is to facilitate the recognition and handling of different //! programming languages by mapping file extensions and converting language enums to tree-sitter //! language objects for accurate parsing and syntax analysis. //! //! # Supported Languages //! - Rust //! - Typescript //! - Python //! - Ruby //! - Javascript //! - Solidity use std::hash::Hash; #[allow(unused_imports)] pub use std::str::FromStr as _; use serde::{Deserialize, Serialize}; /// Enum representing the supported programming languages in the Swiftide project. /// /// This enum is used to map programming languages to their respective file extensions and /// tree-sitter language objects. The `EnumString` and `Display` macros from the `strum_macros` /// crate are used to provide string conversion capabilities. The `ascii_case_insensitive` attribute /// allows for case-insensitive string matching. 
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Copy,
    Deserialize,
    Serialize,
    strum_macros::EnumString,
    strum_macros::Display,
    strum_macros::EnumIter,
    strum_macros::AsRefStr,
)]
#[strum(ascii_case_insensitive)]
#[non_exhaustive]
pub enum SupportedLanguages {
    #[serde(alias = "rust")]
    Rust,
    #[serde(alias = "typescript")]
    Typescript,
    #[serde(alias = "python")]
    Python,
    #[serde(alias = "ruby")]
    Ruby,
    #[serde(alias = "javascript")]
    Javascript,
    #[serde(alias = "java")]
    Java,
    #[serde(alias = "go")]
    Go,
    // Accepts several spellings on input; canonical serialized form is "c-sharp".
    #[serde(rename = "c-sharp", alias = "csharp", alias = "c#", alias = "C#")]
    #[strum(
        serialize = "csharp",
        serialize = "c-sharp",
        serialize = "c#",
        serialize = "C#",
        to_string = "c-sharp"
    )]
    CSharp,
    #[serde(alias = "solidity")]
    Solidity,
    #[serde(alias = "c")]
    C,
    // Accepts several spellings on input; canonical serialized form is "C++".
    #[serde(alias = "cpp", alias = "c++", alias = "C++", rename = "C++")]
    #[strum(
        serialize = "c++",
        serialize = "cpp",
        serialize = "Cpp",
        to_string = "C++"
    )]
    Cpp,
    #[serde(alias = "elixir")]
    Elixir,
    #[serde(alias = "html", alias = "Html")]
    HTML,
    #[serde(alias = "php", alias = "PHP", alias = "Php")]
    PHP,
}

impl Hash for SupportedLanguages {
    /// Implements the `Hash` trait for `SupportedLanguages`.
    ///
    /// This allows instances of `SupportedLanguages` to be used as keys in hash maps and sets.
    /// Hashes the strum `AsRefStr` string form of the variant rather than the
    /// discriminant.
    ///
    /// # Parameters
    /// - `state`: The mutable state to which the hash is added.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.as_ref().hash(state);
    }
}

/// Static array of file extensions for Rust files.
static RUST_EXTENSIONS: &[&str] = &["rs"];
/// Static array of file extensions for Typescript files.
/// NOTE(review): includes "js"/"jsx", overlapping `JAVASCRIPT_EXTENSIONS` — confirm intentional.
static TYPESCRIPT_EXTENSIONS: &[&str] = &["ts", "tsx", "js", "jsx"];
/// Static array of file extensions for Python files.
static PYTHON_EXTENSIONS: &[&str] = &["py"];
/// Static array of file extensions for Ruby files.
static RUBY_EXTENSIONS: &[&str] = &["rb"];
/// Static array of file extensions for Javascript files.
static JAVASCRIPT_EXTENSIONS: &[&str] = &["js", "jsx"];
/// Static array of file extensions for Java files.
static JAVA_EXTENSIONS: &[&str] = &["java"];
/// Static array of file extensions for Go files.
static GO_EXTENSIONS: &[&str] = &["go"];
/// Static array of file extensions for C# files.
static C_SHARP_EXTENSIONS: &[&str] = &["cs", "csx"];
/// Static array of file extensions for Solidity files.
static SOLIDITY_EXTENSIONS: &[&str] = &["sol"];
/// Static array of file extensions for C files.
/// NOTE(review): "o" is normally a compiled object-file extension, not C source — confirm.
static C_EXTENSIONS: &[&str] = &["c", "h", "o"];
/// Static array of file extensions for C++ files.
/// NOTE(review): shares "c"/"h"/"o" with `C_EXTENSIONS` — confirm intentional overlap.
static CPP_EXTENSIONS: &[&str] = &["c", "h", "o", "cc", "cpp"];
static ELIXIR_EXTENSIONS: &[&str] = &["ex", "exs"];
static HTML_EXTENSIONS: &[&str] = &["html", "htm", "xhtml"];
static PHP_EXTENSIONS: &[&str] = &["php"];

impl SupportedLanguages {
    /// Returns the file extensions associated with the supported language.
    ///
    /// # Returns
    /// A static slice of string slices representing the file extensions.
    pub fn file_extensions(&self) -> &[&str] {
        match self {
            SupportedLanguages::Rust => RUST_EXTENSIONS,
            SupportedLanguages::Typescript => TYPESCRIPT_EXTENSIONS,
            SupportedLanguages::Python => PYTHON_EXTENSIONS,
            SupportedLanguages::Ruby => RUBY_EXTENSIONS,
            SupportedLanguages::Javascript => JAVASCRIPT_EXTENSIONS,
            SupportedLanguages::Java => JAVA_EXTENSIONS,
            SupportedLanguages::Go => GO_EXTENSIONS,
            SupportedLanguages::CSharp => C_SHARP_EXTENSIONS,
            SupportedLanguages::Solidity => SOLIDITY_EXTENSIONS,
            SupportedLanguages::C => C_EXTENSIONS,
            SupportedLanguages::Cpp => CPP_EXTENSIONS,
            SupportedLanguages::Elixir => ELIXIR_EXTENSIONS,
            SupportedLanguages::HTML => HTML_EXTENSIONS,
            SupportedLanguages::PHP => PHP_EXTENSIONS,
        }
    }
}

impl From<SupportedLanguages> for tree_sitter::Language {
    /// Converts a `SupportedLanguages` enum to a `tree_sitter::Language` object.
    ///
    /// This implementation allows for the conversion of the supported languages to their respective
    /// tree-sitter language objects, enabling accurate parsing and syntax analysis.
    ///
    /// # Parameters
    /// - `val`: The `SupportedLanguages` enum value to be converted.
    ///
    /// # Returns
    /// A `tree_sitter::Language` object corresponding to the provided `SupportedLanguages` enum
    /// value.
    fn from(val: SupportedLanguages) -> Self {
        match val {
            SupportedLanguages::Rust => tree_sitter_rust::LANGUAGE,
            SupportedLanguages::Python => tree_sitter_python::LANGUAGE,
            SupportedLanguages::Typescript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT,
            SupportedLanguages::Javascript => tree_sitter_javascript::LANGUAGE,
            SupportedLanguages::Ruby => tree_sitter_ruby::LANGUAGE,
            SupportedLanguages::Java => tree_sitter_java::LANGUAGE,
            SupportedLanguages::Go => tree_sitter_go::LANGUAGE,
            SupportedLanguages::CSharp => tree_sitter_c_sharp::LANGUAGE,
            SupportedLanguages::Solidity => tree_sitter_solidity::LANGUAGE,
            SupportedLanguages::C => tree_sitter_c::LANGUAGE,
            SupportedLanguages::Cpp => tree_sitter_cpp::LANGUAGE,
            SupportedLanguages::Elixir => tree_sitter_elixir::LANGUAGE,
            SupportedLanguages::HTML => tree_sitter_html::LANGUAGE,
            SupportedLanguages::PHP => tree_sitter_php::LANGUAGE_PHP,
        }
        .into()
    }
}

#[cfg(test)]
mod test {
    use super::*;
    pub use strum::IntoEnumIterator as _;

    /// Tests the case-insensitive string conversion for `SupportedLanguages`.
    #[test]
    fn test_supported_languages_from_str() {
        assert_eq!(
            SupportedLanguages::from_str("rust"),
            Ok(SupportedLanguages::Rust)
        );
        assert_eq!(
            SupportedLanguages::from_str("typescript"),
            Ok(SupportedLanguages::Typescript)
        );
        assert_eq!(
            SupportedLanguages::from_str("java"),
            Ok(SupportedLanguages::Java)
        );
        assert_eq!(
            SupportedLanguages::from_str("c-sharp"),
            Ok(SupportedLanguages::CSharp)
        );
    }

    /// Tests the case-insensitive string conversion for `SupportedLanguages` with different casing.
#[test] fn test_supported_languages_from_str_case_insensitive() { assert_eq!( SupportedLanguages::from_str("Rust"), Ok(SupportedLanguages::Rust) ); assert_eq!( SupportedLanguages::from_str("TypeScript"), Ok(SupportedLanguages::Typescript) ); assert_eq!( SupportedLanguages::from_str("Java"), Ok(SupportedLanguages::Java) ); assert_eq!( SupportedLanguages::from_str("C-Sharp"), Ok(SupportedLanguages::CSharp) ); assert_eq!( SupportedLanguages::from_str("C++"), Ok(SupportedLanguages::Cpp) ); assert_eq!( SupportedLanguages::from_str("cpp"), Ok(SupportedLanguages::Cpp) ); assert_eq!( SupportedLanguages::from_str("elixir"), Ok(SupportedLanguages::Elixir) ); } #[test] fn test_serialize_and_deserialize_for_supported_languages() { for lang in SupportedLanguages::iter() { let val = serde_json::to_string(&lang).unwrap(); assert_eq!( serde_json::to_string(&lang).unwrap(), format!("\"{lang}\""), "Failed to serialize {lang}" ); assert_eq!( serde_json::from_str::<SupportedLanguages>(&val).unwrap(), lang, "Failed to deserialize {lang}" ); assert_eq!( serde_json::from_str::<SupportedLanguages>(&val.to_lowercase()).unwrap(), lang, "Failed to deserialize lowercase {lang}" ); } } } ================================================ FILE: swiftide-langfuse/Cargo.toml ================================================ cargo-features = ["edition2024"] [package] name = "swiftide-langfuse" version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true [dependencies] # TODO: Go over these serde.workspace = true serde_with = { version = "^3.8", default-features = false, features = [ "base64", "std", "macros", ] } serde_json.workspace = true serde_repr = "^0.1" chrono = { workspace = true, features = ["now"] } tokio.workspace = true uuid.workspace = true tracing.workspace = true tracing-subscriber.workspace = true futures = 
"^0.3" url = "^2.5" reqwest = { version = "^0.13", default-features = false, features = [ "json", "multipart", ] } anyhow.workspace = true async-trait.workspace = true dyn-clone.workspace = true swiftide-core = { path = "../swiftide-core", version = "0.32" } [dev-dependencies] wiremock.workspace = true test-log.workspace = true insta.workspace = true tracing-appender = "0.2.3" # We need some custom stuff because of the codegen and me being lazy [lints.rust] dead_code = "warn" [lints.clippy] cargo = { level = "warn", priority = -1 } # pedantic = { level = "warn", priority = -1 } blocks_in_conditions = "allow" must_use_candidate = "allow" module_name_repetitions = "allow" missing_fields_in_debug = "allow" multiple_crate_versions = "allow" option_option = "allow" ================================================ FILE: swiftide-langfuse/src/apis/configuration.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech #[derive(Debug, Clone)] pub struct Configuration { pub base_path: String, pub user_agent: Option<String>, pub client: reqwest::Client, pub basic_auth: Option<BasicAuth>, pub oauth_access_token: Option<String>, pub bearer_access_token: Option<String>, pub api_key: Option<ApiKey>, } pub type BasicAuth = (String, Option<String>); #[derive(Debug, Clone)] pub struct ApiKey { pub prefix: Option<String>, pub key: String, } impl Configuration { pub fn new() -> Configuration { Configuration::default() } } impl Default for Configuration { fn default() -> Self { Configuration { 
base_path: "http://localhost".to_owned(), user_agent: Some("OpenAPI-Generator//rust".to_owned()), client: reqwest::Client::new(), basic_auth: None, oauth_access_token: None, bearer_access_token: None, api_key: None, } } } ================================================ FILE: swiftide-langfuse/src/apis/ingestion_api.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use super::{ContentType, Error, configuration}; use crate::{apis::ResponseContent, models}; use reqwest; use serde::{Deserialize, Serialize, de::Error as _}; /// struct for typed errors of method [`ingestion_batch`] #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum IngestionBatchError { Status400(serde_json::Value), Status401(serde_json::Value), Status403(serde_json::Value), Status404(serde_json::Value), Status405(serde_json::Value), UnknownValue(serde_json::Value), } /// Batched ingestion for Langfuse Tracing. If you want to use tracing via the API, such as to build your own Langfuse client implementation, this is the only API route you need to implement. Within each batch, there can be multiple events. Each event has a type, an id, a timestamp, metadata and a body. Internally, we refer to this as the \"event envelope\" as it tells us something about the event but not the trace. We use the event id within this envelope to deduplicate messages to avoid processing the same event twice, i.e. the event id should be unique per request. 
The event.body.id is the ID of the actual trace and will be used for updates and will be visible within the Langfuse App. I.e. if you want to update a trace, you'd use the same body id, but separate event IDs. Notes: - Introduction to data model: <https://langfuse.com/docs/tracing-data-model> - Batch sizes are limited to 3.5 MB in total. You need to adjust the number of events per batch accordingly. - The API does not return a 4xx status code for input errors. Instead, it responds with a 207 status code, which includes a list of the encountered errors. pub async fn ingestion_batch( configuration: &configuration::Configuration, ingestion_batch_request: &models::IngestionBatchRequest, ) -> Result<models::IngestionResponse, Error<IngestionBatchError>> { // add a prefix to parameters to efficiently prevent name collisions let p_ingestion_batch_request = ingestion_batch_request; let uri_str = format!("{}/api/public/ingestion", configuration.base_path); let mut req_builder = configuration .client .request(reqwest::Method::POST, &uri_str); if let Some(ref user_agent) = configuration.user_agent { req_builder = req_builder.header(reqwest::header::USER_AGENT, user_agent.clone()); } if let Some(ref auth_conf) = configuration.basic_auth { req_builder = req_builder.basic_auth(auth_conf.0.clone(), auth_conf.1.clone()); } req_builder = req_builder.json(&p_ingestion_batch_request); let req = req_builder.build()?; let resp = configuration.client.execute(req).await?; let status = resp.status(); let content_type = resp .headers() .get("content-type") .and_then(|v| v.to_str().ok()) .unwrap_or("application/octet-stream"); let content_type = super::ContentType::from(content_type); if !status.is_client_error() && !status.is_server_error() { let content = resp.text().await?; match content_type { ContentType::Json => serde_json::from_str(&content).map_err(Error::from), ContentType::Text => Err(Error::from(serde_json::Error::custom( "Received `text/plain` content type response that cannot 
be converted to `models::IngestionResponse`", ))), ContentType::Unsupported(unknown_type) => { Err(Error::from(serde_json::Error::custom(format!( "Received `{unknown_type}` content type response that cannot be converted to `models::IngestionResponse`" )))) } } } else { let content = resp.text().await?; let entity: Option<IngestionBatchError> = serde_json::from_str(&content).ok(); Err(Error::ResponseError(ResponseContent { status, content, entity, })) } } ================================================ FILE: swiftide-langfuse/src/apis/mod.rs ================================================ use std::error; use std::fmt; #[derive(Debug, Clone)] #[allow(dead_code)] pub struct ResponseContent<T> { pub status: reqwest::StatusCode, pub content: String, pub entity: Option<T>, } #[derive(Debug)] pub enum Error<T> { Reqwest(reqwest::Error), Serde(serde_json::Error), Io(std::io::Error), #[allow(clippy::enum_variant_names)] ResponseError(ResponseContent<T>), } impl<T> fmt::Display for Error<T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let (module, e) = match self { Error::Reqwest(e) => ("reqwest", e.to_string()), Error::Serde(e) => ("serde", e.to_string()), Error::Io(e) => ("IO", e.to_string()), Error::ResponseError(e) => ("response", format!("status code {}", e.status)), }; write!(f, "error in {module}: {e}") } } impl<T: fmt::Debug> error::Error for Error<T> { fn source(&self) -> Option<&(dyn error::Error + 'static)> { Some(match self { Error::Reqwest(e) => e, Error::Serde(e) => e, Error::Io(e) => e, Error::ResponseError(_) => return None, }) } } impl<T> From<reqwest::Error> for Error<T> { fn from(e: reqwest::Error) -> Self { Error::Reqwest(e) } } impl<T> From<serde_json::Error> for Error<T> { fn from(e: serde_json::Error) -> Self { Error::Serde(e) } } impl<T> From<std::io::Error> for Error<T> { fn from(e: std::io::Error) -> Self { Error::Io(e) } } /// Internal use only /// A content type supported by this client. 
#[allow(dead_code)]
enum ContentType {
    Json,
    Text,
    Unsupported(String),
}

impl From<&str> for ContentType {
    /// Classifies a raw `Content-Type` header value.
    ///
    /// Any `application/*` type whose value contains "json" (e.g. `application/json`,
    /// `application/ld+json`) maps to `Json`; values starting with `text/plain`
    /// (with or without parameters) map to `Text`; everything else is preserved
    /// verbatim in `Unsupported`.
    fn from(content_type: &str) -> Self {
        if content_type.starts_with("application") && content_type.contains("json") {
            Self::Json
        } else if content_type.starts_with("text/plain") {
            Self::Text
        } else {
            Self::Unsupported(content_type.to_string())
        }
    }
}

pub mod configuration;
pub mod ingestion_api;

use crate::apis::configuration::Configuration;
use crate::apis::ingestion_api::ingestion_batch;
use crate::models::{IngestionBatchRequest, IngestionEvent};
use anyhow::Result;
use async_trait::async_trait;
use dyn_clone::DynClone;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::Mutex;

/// Buffers Langfuse ingestion events in memory and ships them in batches.
///
/// Cloning is shallow: all clones share the same configuration, event buffer,
/// and drop flag via `Arc`.
#[derive(Debug, Default, Clone)]
pub struct LangfuseBatchManager {
    config: Arc<Configuration>,
    // Pending events, drained by `send_async`.
    pub batch: Arc<Mutex<Vec<IngestionEvent>>>,
    // Set when the last clone is dropped; guards against sending after drop.
    dropped: Arc<AtomicBool>,
}

/// Object-safe facade over a batch manager so consumers can hold a boxed,
/// clonable handle.
#[async_trait]
pub trait BatchManagerTrait: Send + Sync + DynClone {
    /// Queues a single event for a later batch send.
    async fn add_event(&self, event: IngestionEvent);
    /// Sends any pending events immediately.
    async fn flush(&self) -> anyhow::Result<()>;
    /// Clones `self` into a boxed trait object.
    fn boxed(&self) -> Box<dyn BatchManagerTrait + Send + Sync>;
}

dyn_clone::clone_trait_object!(BatchManagerTrait);

impl LangfuseBatchManager {
    /// Creates a manager with an empty event buffer for the given API configuration.
    pub fn new(config: Configuration) -> Self {
        Self {
            config: Arc::new(config),
            batch: Arc::new(Mutex::new(Vec::new())),
            // Locally track if the manager has been dropped to avoid spawning tasks after drop
            dropped: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Consumes `self` and spawns a background task that flushes the buffer
    /// every five seconds.
    ///
    /// NOTE(review): the spawned loop never breaks; after the manager is
    /// dropped, `send_async` becomes a no-op but the task keeps ticking —
    /// consider exiting the loop when `dropped` is set.
    pub fn spawn(self) {
        if self.dropped.load(Ordering::Relaxed) {
            tracing::trace!("LangfuseBatchManager has been dropped, not spawning sender task");
            return;
        }
        const BATCH_INTERVAL: Duration = Duration::from_secs(5);
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(BATCH_INTERVAL).await;
                if let Err(e) = self.send_async().await {
                    tracing::error!(
                        error.msg = %e,
                        error.type = %std::any::type_name_of_val(&e),
                        "Failed to send batch to Langfuse"
                    );
                }
            }
        });
    }

    /// Sends pending events now; does nothing when the buffer is empty.
    pub async fn flush(&self) -> Result<()> {
        let lock = self.batch.lock().await;
        if !lock.is_empty() {
            // Release the lock first — `send_async` re-acquires it.
            drop(lock);
            self.send_async().await?;
        }
        Ok(())
    }

    /// Drains the buffer and submits it as one ingestion batch.
    ///
    /// Partial failures are logged and dropped; if *no* item succeeded the whole
    /// batch is re-queued for the next attempt and, if the server also reported
    /// errors, an error is returned.
    pub async fn send_async(&self) -> Result<()> {
        tracing::trace!("Sending batch to Langfuse");
        if self.dropped.load(Ordering::Relaxed) {
            tracing::error!("LangfuseBatchManager has been dropped, not sending batch");
            return Ok(());
        }
        let mut batch_guard = self.batch.lock().await;
        if batch_guard.is_empty() {
            return Ok(());
        }
        // Take the whole buffer, leaving it empty for concurrent `add_event`s.
        let batch = std::mem::take(&mut *batch_guard);
        let mut payload = IngestionBatchRequest {
            batch,
            metadata: None, // Optional metadata can be added here if needed
        };
        drop(batch_guard); // Release the lock before making the network call
        let response = ingestion_batch(&self.config, &payload).await?;
        for error in &response.errors {
            // Any errors we log and ignore, no retry
            tracing::error!(
                id = %error.id,
                status = error.status,
                message = error.message.as_ref().unwrap_or(&None).as_deref().unwrap_or("No message"),
                error = ?error.error,
                "Partial failure in batch ingestion"
            );
        }
        if response.successes.is_empty() {
            tracing::error!("All items in the batch failed, retrying all items");
            // Re-queue behind any events added while the request was in flight.
            let mut batch_guard = self.batch.lock().await;
            batch_guard.append(&mut payload.batch);
        }
        if response.successes.is_empty() && !response.errors.is_empty() {
            anyhow::bail!("Langfuse ingestion failed for all items");
        } else {
            Ok(())
        }
    }

    /// Appends one event to the pending buffer.
    pub async fn add_event(&self, event: IngestionEvent) {
        self.batch.lock().await.push(event);
    }
}

#[async_trait]
impl BatchManagerTrait for LangfuseBatchManager {
    async fn add_event(&self, event: IngestionEvent) {
        self.add_event(event).await;
    }

    async fn flush(&self) -> anyhow::Result<()> {
        self.flush().await
    }

    fn boxed(&self) -> Box<dyn BatchManagerTrait + Send + Sync> {
        Box::new(self.clone())
    }
}

impl Drop for LangfuseBatchManager {
    /// Best-effort final flush when the *last* handle to this manager goes away.
    ///
    /// NOTE(review): `Handle::current()` (and `spawn_blocking`) panic when the
    /// value is dropped outside a Tokio runtime — confirm all drop sites run
    /// inside one.
    fn drop(&mut self) {
        if Arc::strong_count(&self.dropped) > 1 {
            // There are other references to this manager, don't flush yet
            return;
        }
        if self.dropped.swap(true, Ordering::SeqCst) {
            // Already dropped
            return;
        }
        // The clone below bumps the strong count again; its own drop is cut
        // short by the `dropped` flag set above.
        let this = self.clone();
        tokio::task::spawn_blocking(move || {
            let handle = tokio::runtime::Handle::current();
            if let Err(e) = handle.block_on(async move { this.flush().await }) {
                tracing::error!("Error flushing LangfuseBatchManager on drop: {:?}", e);
            }
        });
    }
}

//! Provides a Langfuse integration for Swiftide
//!
//! Agents and completion traits will report their input, output, and usage to langfuse.
//!
//! The `LangfuseLayer` needs to be set up like any other tracing layer.
//!
//! By default, it requires the LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables
//! to be set. You can also provide a custom Langfuse URL via the LANGFUSE_URL environment
//! variable.
//!
//! All `Langfuse` data is on the `debug` level. Make sure your tracing setup captures that level.
//!
//! # Example
//! ```no_run
//! # use swiftide_langfuse::LangfuseLayer;
//! # use tracing::metadata::LevelFilter;
//! # use tracing_subscriber::prelude::*;
//!
//! // Assuming you have other layers
//! let mut layers = Vec::new();
//! layers.push(LangfuseLayer::default().with_filter(LevelFilter::DEBUG).boxed());
//!
//! let registry = tracing_subscriber::registry()
//!     .with(layers);
//!
//! registry.init();
//! ```
//!
//! For more advanced usage, refer to the `LangfuseLayer` documentation.
//!
//! Refer to the [Langfuse documentation](https://langfuse.com/docs/) for more details on how to setup Langfuse itself.
mod apis;
mod langfuse_batch_manager;
mod models;
mod tracing_layer;

// Default Langfuse endpoint; per the crate docs a custom URL can be supplied
// via the LANGFUSE_URL environment variable.
const DEFAULT_LANGFUSE_URL: &str = "http://localhost:3000";

pub use crate::apis::configuration::Configuration;
pub use crate::langfuse_batch_manager::LangfuseBatchManager;
pub use crate::tracing_layer::LangfuseLayer;

================================================
FILE: swiftide-langfuse/src/models/create_event_body.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Payload of an `event-create` ingestion event (wrapped by `models::EventCreate`).
///
/// Every field is optional on the wire. The `Option<Option<T>>` +
/// `serde_with::rust::double_option` pattern distinguishes a field that is
/// omitted entirely (outer `None`) from one explicitly set to JSON `null`
/// (inner `None`).
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEventBody {
    #[serde(rename = "id", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub id: Option<Option<String>>,
    #[serde(rename = "traceId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub trace_id: Option<Option<String>>,
    #[serde(rename = "name", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub name: Option<Option<String>>,
    #[serde(rename = "startTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub start_time: Option<Option<String>>,
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    #[serde(rename = "input", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub input: Option<Option<serde_json::Value>>,
    #[serde(rename = "output", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub output: Option<Option<serde_json::Value>>,
    // Plain Option: `level` has no explicit-null representation on the wire.
    #[serde(rename = "level", skip_serializing_if = "Option::is_none")]
    pub level: Option<models::ObservationLevel>,
    #[serde(rename = "statusMessage", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub status_message: Option<Option<String>>,
    #[serde(rename = "parentObservationId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub parent_observation_id: Option<Option<String>>,
    #[serde(rename = "version", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub version: Option<Option<String>>,
    #[serde(rename = "environment", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub environment: Option<Option<String>>,
}

impl CreateEventBody {
    /// Creates an empty body (all fields unset); equivalent to `Self::default()`.
    pub fn new() -> CreateEventBody {
        CreateEventBody {
            id: None,
            trace_id: None,
            name: None,
            start_time: None,
            metadata: None,
            input: None,
            output: None,
            level: None,
            status_message: None,
            parent_observation_id: None,
            version: None,
            environment: None,
        }
    }
}

================================================
FILE: swiftide-langfuse/src/models/create_generation_body.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by:
// https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Payload of a `generation-create` ingestion event (wrapped by
/// `models::GenerationCreate`).
///
/// Every field is optional on the wire. The `Option<Option<T>>` +
/// `serde_with::rust::double_option` pattern distinguishes an omitted field
/// (outer `None`) from an explicit JSON `null` (inner `None`).
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateGenerationBody {
    #[serde(rename = "completionStartTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub completion_start_time: Option<Option<String>>,
    #[serde(rename = "model", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub model: Option<Option<String>>,
    #[serde(rename = "modelParameters", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub model_parameters: Option<Option<std::collections::HashMap<String, models::MapValue>>>,
    // Plain Option: no explicit-null representation for `usage`/`usageDetails`.
    #[serde(rename = "usage", skip_serializing_if = "Option::is_none")]
    pub usage: Option<Box<models::IngestionUsage>>,
    #[serde(rename = "usageDetails", skip_serializing_if = "Option::is_none")]
    pub usage_details: Option<Box<models::UsageDetails>>,
    #[serde(rename = "costDetails", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub cost_details: Option<Option<std::collections::HashMap<String, f64>>>,
    #[serde(rename = "promptName", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub prompt_name: Option<Option<String>>,
    #[serde(rename = "promptVersion", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub prompt_version: Option<Option<i32>>,
    #[serde(rename = "endTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub end_time: Option<Option<String>>,
    #[serde(rename = "id", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub id: Option<Option<String>>,
    #[serde(rename = "traceId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub trace_id: Option<Option<String>>,
    #[serde(rename = "name", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub name: Option<Option<String>>,
    #[serde(rename = "startTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub start_time: Option<Option<String>>,
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    #[serde(rename = "input", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub input: Option<Option<serde_json::Value>>,
    #[serde(rename = "output", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub output: Option<Option<serde_json::Value>>,
    #[serde(rename = "level", skip_serializing_if = "Option::is_none")]
    pub level: Option<models::ObservationLevel>,
    #[serde(rename = "statusMessage", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub status_message: Option<Option<String>>,
    #[serde(rename = "parentObservationId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub parent_observation_id: Option<Option<String>>,
    #[serde(rename = "version", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub version: Option<Option<String>>,
    #[serde(rename = "environment", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub environment: Option<Option<String>>,
}

impl CreateGenerationBody {
    /// Creates an empty body (all fields unset); equivalent to `Self::default()`.
    pub fn new() -> CreateGenerationBody {
        CreateGenerationBody {
            completion_start_time: None,
            model: None,
            model_parameters: None,
            usage: None,
            usage_details: None,
            cost_details: None,
            prompt_name: None,
            prompt_version: None,
            end_time: None,
            id: None,
            trace_id: None,
            name: None,
            start_time: None,
            metadata: None,
            input: None,
            output: None,
            level: None,
            status_message: None,
            parent_observation_id: None,
            version: None,
            environment: None,
        }
    }
}

================================================
FILE: swiftide-langfuse/src/models/create_score_value.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use serde::{Deserialize, Serialize};

/// `CreateScoreValue` : The value of the score. Must be passed as string for categorical scores,
/// and numeric for boolean and numeric scores.
/// (Serialized `untagged`: the JSON value itself — number or string — carries the variant.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CreateScoreValue {
    /// Numeric value (boolean and numeric scores).
    Number(f64),
    /// String value (categorical scores).
    String(String),
}

impl Default for CreateScoreValue {
    fn default() -> Self {
        // Defaults to the number 0.0.
        Self::Number(Default::default())
    }
}

================================================
FILE: swiftide-langfuse/src/models/create_span_body.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Payload of a `span-create` ingestion event (wrapped by `models::SpanCreate`).
///
/// Every field is optional on the wire. The `Option<Option<T>>` +
/// `serde_with::rust::double_option` pattern distinguishes an omitted field
/// (outer `None`) from an explicit JSON `null` (inner `None`).
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateSpanBody {
    #[serde(rename = "endTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub end_time: Option<Option<String>>,
    #[serde(rename = "id", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub id: Option<Option<String>>,
    #[serde(rename = "traceId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub trace_id: Option<Option<String>>,
    #[serde(rename = "name", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub name: Option<Option<String>>,
    #[serde(rename = "startTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub start_time: Option<Option<String>>,
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    #[serde(rename = "input", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub input: Option<Option<serde_json::Value>>,
    #[serde(rename = "output", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub output: Option<Option<serde_json::Value>>,
    // Plain Option: `level` has no explicit-null representation on the wire.
    #[serde(rename = "level", skip_serializing_if = "Option::is_none")]
    pub level: Option<models::ObservationLevel>,
    #[serde(rename = "statusMessage", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub status_message: Option<Option<String>>,
    #[serde(rename = "parentObservationId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub parent_observation_id: Option<Option<String>>,
    #[serde(rename = "version", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub version: Option<Option<String>>,
    #[serde(rename = "environment", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub environment: Option<Option<String>>,
}

impl CreateSpanBody {
    /// Creates an empty body (all fields unset); equivalent to `Self::default()`.
    pub fn new() -> CreateSpanBody {
        CreateSpanBody {
            end_time: None,
            id: None,
            trace_id: None,
            name: None,
            start_time: None,
            metadata: None,
            input: None,
            output: None,
            level: None,
            status_message: None,
            parent_observation_id: None,
            version: None,
            environment: None,
        }
    }
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_batch_request.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec:
// https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Request body for the Langfuse batched ingestion endpoint.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct IngestionBatchRequest {
    /// Batch of tracing events to be ingested. Discriminated by attribute `type`.
    #[serde(rename = "batch")]
    pub batch: Vec<models::IngestionEvent>,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    // double_option: distinguishes an omitted field from an explicit JSON null.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_error.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use serde::{Deserialize, Serialize};

/// A per-event error entry returned by the ingestion endpoint.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct IngestionError {
    // NOTE(review): presumably the id of the event the error refers to — confirm
    // against the Langfuse ingestion API reference.
    #[serde(rename = "id")]
    pub id: String,
    #[serde(rename = "status")]
    pub status: i32,
    // double_option on the fields below distinguishes an omitted field from an
    // explicit JSON null.
    #[serde(rename = "message", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub message: Option<Option<String>>,
    #[serde(rename = "error", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub error: Option<Option<serde_json::Value>>,
}
================================================ FILE: swiftide-langfuse/src/models/ingestion_event.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use crate::models; use chrono::Utc; use serde::{Deserialize, Serialize}; use uuid::Uuid; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(untagged)] pub enum IngestionEvent { TraceCreate(Box<models::TraceCreate>), ScoreCreate(Box<models::ScoreCreate>), SpanCreate(Box<models::SpanCreate>), SpanUpdate(Box<models::SpanUpdate>), GenerationCreate(Box<models::GenerationCreate>), GenerationUpdate(Box<models::GenerationUpdate>), EventCreate(Box<models::EventCreate>), SdkLog(Box<models::SdkLog>), ObservationCreate(Box<models::ObservationCreate>), ObservationUpdate(Box<models::ObservationUpdate>), } impl Default for IngestionEvent { fn default() -> Self { Self::TraceCreate(Default::default()) } } impl IngestionEvent { pub fn new_trace_create(body: models::TraceBody) -> Self { IngestionEvent::TraceCreate(Box::new(models::TraceCreate::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of::Type::TraceCreate, ))) } pub fn new_score_create(body: models::ScoreBody) -> Self { IngestionEvent::ScoreCreate(Box::new(models::ScoreCreate::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_1::Type::ScoreCreate, ))) } pub fn new_span_create(body: models::CreateSpanBody) -> Self { IngestionEvent::SpanCreate(Box::new(models::SpanCreate::new( body, 
Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_2::Type::SpanCreate, ))) } pub fn new_span_update(body: models::UpdateSpanBody) -> Self { IngestionEvent::SpanUpdate(Box::new(models::SpanUpdate::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_3::Type::SpanUpdate, ))) } pub fn new_generation_create(body: models::CreateGenerationBody) -> Self { IngestionEvent::GenerationCreate(Box::new(models::GenerationCreate::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_4::Type::GenerationCreate, ))) } pub fn new_generation_update(body: models::UpdateGenerationBody) -> Self { IngestionEvent::GenerationUpdate(Box::new(models::GenerationUpdate::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_5::Type::GenerationUpdate, ))) } pub fn new_event_create(body: models::CreateEventBody) -> Self { IngestionEvent::EventCreate(Box::new(models::EventCreate::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_6::Type::EventCreate, ))) } pub fn new_sdk_log(body: models::SdkLogBody) -> Self { IngestionEvent::SdkLog(Box::new(models::SdkLog::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_7::Type::SdkLog, ))) } pub fn new_observation_create(body: models::ObservationBody) -> Self { IngestionEvent::ObservationCreate(Box::new(models::ObservationCreate::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_8::Type::ObservationCreate, ))) } pub fn new_observation_update(body: models::ObservationBody) -> Self { IngestionEvent::ObservationUpdate(Box::new(models::ObservationUpdate::new( body, Uuid::new_v4().to_string(), Utc::now().to_rfc3339(), models::ingestion_event_one_of_9::Type::ObservationUpdate, ))) } } ================================================ FILE: 
swiftide-langfuse/src/models/ingestion_event_one_of.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for a `trace-create` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct TraceCreate {
    #[serde(rename = "body")]
    pub body: Box<models::TraceBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "trace-create".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl TraceCreate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::TraceBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> TraceCreate {
        TraceCreate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `trace-create` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "trace-create")]
    #[default]
    TraceCreate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_1.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for a `score-create` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct ScoreCreate {
    #[serde(rename = "body")]
    pub body: Box<models::ScoreBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "score-create".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl ScoreCreate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::ScoreBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> ScoreCreate {
        ScoreCreate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `score-create` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "score-create")]
    #[default]
    ScoreCreate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_2.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for a `span-create` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct SpanCreate {
    #[serde(rename = "body")]
    pub body: Box<models::CreateSpanBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "span-create".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl SpanCreate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::CreateSpanBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> SpanCreate {
        SpanCreate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `span-create` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "span-create")]
    #[default]
    SpanCreate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_3.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for a `span-update` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct SpanUpdate {
    #[serde(rename = "body")]
    pub body: Box<models::UpdateSpanBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "span-update".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl SpanUpdate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::UpdateSpanBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> SpanUpdate {
        SpanUpdate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `span-update` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "span-update")]
    #[default]
    SpanUpdate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_4.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for a `generation-create` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct GenerationCreate {
    #[serde(rename = "body")]
    pub body: Box<models::CreateGenerationBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "generation-create".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl GenerationCreate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::CreateGenerationBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> GenerationCreate {
        GenerationCreate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `generation-create` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "generation-create")]
    #[default]
    GenerationCreate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_5.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for a `generation-update` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct GenerationUpdate {
    #[serde(rename = "body")]
    pub body: Box<models::UpdateGenerationBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "generation-update".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl GenerationUpdate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::UpdateGenerationBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> GenerationUpdate {
        GenerationUpdate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `generation-update` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "generation-update")]
    #[default]
    GenerationUpdate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_6.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for an `event-create` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct EventCreate {
    #[serde(rename = "body")]
    pub body: Box<models::CreateEventBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "event-create".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl EventCreate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::CreateEventBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> EventCreate {
        EventCreate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `event-create` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "event-create")]
    #[default]
    EventCreate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_7.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for an `sdk-log` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct SdkLog {
    #[serde(rename = "body")]
    pub body: Box<models::SdkLogBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "sdk-log".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl SdkLog {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(body: models::SdkLogBody, id: String, timestamp: String, r#type: Type) -> SdkLog {
        SdkLog {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `sdk-log` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "sdk-log")]
    #[default]
    SdkLog,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_8.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for an `observation-create` ingestion event.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct ObservationCreate {
    #[serde(rename = "body")]
    pub body: Box<models::ObservationBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "observation-create".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl ObservationCreate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::ObservationBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> ObservationCreate {
        ObservationCreate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `observation-create` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "observation-create")]
    #[default]
    ObservationCreate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_event_one_of_9.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Envelope for an `observation-update` ingestion event.
/// Note: shares `models::ObservationBody` with `ObservationCreate`; only the
/// discriminator differs.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct ObservationUpdate {
    #[serde(rename = "body")]
    pub body: Box<models::ObservationBody>,
    /// UUID v4 that identifies the event
    #[serde(rename = "id")]
    pub id: String,
    /// Datetime (ISO 8601) of event creation in client. Should be as close to actual event
    /// creation in client as possible, this timestamp will be used for ordering of events in
    /// future release. Resolution: milliseconds (required), microseconds (optimal).
    #[serde(rename = "timestamp")]
    pub timestamp: String,
    /// Optional. Metadata field used by the Langfuse SDKs for debugging.
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    // Discriminator; always serialized as "observation-update".
    #[serde(rename = "type")]
    pub r#type: Type,
}

impl ObservationUpdate {
    /// Builds the envelope from payload, event id, client timestamp and discriminator.
    /// `metadata` starts out unset.
    pub fn new(
        body: models::ObservationBody,
        id: String,
        timestamp: String,
        r#type: Type,
    ) -> ObservationUpdate {
        ObservationUpdate {
            body: Box::new(body),
            id,
            timestamp,
            metadata: None,
            r#type,
        }
    }
}

/// Single-variant discriminator tag for `observation-update` events.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default)]
pub enum Type {
    #[serde(rename = "observation-update")]
    #[default]
    ObservationUpdate,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_response.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json
//
// The version of the OpenAPI document:
//
// Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Response of the batched ingestion endpoint: per-event success and error lists.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct IngestionResponse {
    #[serde(rename = "successes")]
    pub successes: Vec<models::IngestionSuccess>,
    #[serde(rename = "errors")]
    pub errors: Vec<models::IngestionError>,
}

================================================
FILE: swiftide-langfuse/src/models/ingestion_success.rs
================================================
// langfuse
//
// ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ##
Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use serde::{Deserialize, Serialize}; #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] pub struct IngestionSuccess { #[serde(rename = "id")] pub id: String, #[serde(rename = "status")] pub status: i32, } ================================================ FILE: swiftide-langfuse/src/models/ingestion_usage.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use crate::models; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(untagged)] pub enum IngestionUsage { Usage(Box<models::Usage>), OpenAiUsage(Box<models::OpenAiUsage>), } impl Default for IngestionUsage { fn default() -> Self { Self::Usage(Default::default()) } } ================================================ FILE: swiftide-langfuse/src/models/map_value.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: 
https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(untagged)] pub enum MapValue { String(String), Integer(i32), Boolean(bool), Array(Vec<String>), } impl Default for MapValue { fn default() -> Self { Self::String(Default::default()) } } ================================================ FILE: swiftide-langfuse/src/models/mod.rs ================================================ // pub mod annotation_queue; // pub use self::annotation_queue::AnnotationQueue; // pub mod annotation_queue_assignment_request; // pub use self::annotation_queue_assignment_request::AnnotationQueueAssignmentRequest; // pub mod annotation_queue_item; // pub use self::annotation_queue_item::AnnotationQueueItem; // pub mod annotation_queue_object_type; // pub use self::annotation_queue_object_type::AnnotationQueueObjectType; // pub mod annotation_queue_status; // pub use self::annotation_queue_status::AnnotationQueueStatus; // pub mod api_key_deletion_response; // pub use self::api_key_deletion_response::ApiKeyDeletionResponse; // pub mod api_key_list; // pub use self::api_key_list::ApiKeyList; // pub mod api_key_response; // pub use self::api_key_response::ApiKeyResponse; // pub mod api_key_summary; // pub use self::api_key_summary::ApiKeySummary; // pub mod authentication_scheme; // pub use self::authentication_scheme::AuthenticationScheme; // pub mod base_event; // pub use self::base_event::BaseEvent; // pub mod base_prompt; // pub use self::base_prompt::BasePrompt; // pub mod base_score; // pub use self::base_score::BaseScore; // pub mod base_score_v1; // pub use self::base_score_v1::BaseScoreV1; // pub mod boolean_score; // pub use self::boolean_score::BooleanScore; // pub mod boolean_score_v1; // pub use self::boolean_score_v1::BooleanScoreV1; // pub mod bulk_config; // pub use 
self::bulk_config::BulkConfig; // pub mod categorical_score; // pub use self::categorical_score::CategoricalScore; // pub mod categorical_score_v1; // pub use self::categorical_score_v1::CategoricalScoreV1; // pub mod chat_message; // pub use self::chat_message::ChatMessage; // pub mod chat_message_with_placeholders; // pub use self::chat_message_with_placeholders::ChatMessageWithPlaceholders; // pub mod chat_message_with_placeholders_one_of; // pub use self::chat_message_with_placeholders_one_of::ChatMessageWithPlaceholdersOneOf; // pub mod chat_message_with_placeholders_one_of_1; // pub use self::chat_message_with_placeholders_one_of_1::ChatMessageWithPlaceholdersOneOf1; // pub mod chat_prompt; // pub use self::chat_prompt::ChatPrompt; // pub mod comment; // pub use self::comment::Comment; // pub mod comment_object_type; // pub use self::comment_object_type::CommentObjectType; // pub mod config_category; // pub use self::config_category::ConfigCategory; // pub mod create_annotation_queue_assignment_response; // pub use self::create_annotation_queue_assignment_response::CreateAnnotationQueueAssignmentResponse; // pub mod create_annotation_queue_item_request; // pub use self::create_annotation_queue_item_request::CreateAnnotationQueueItemRequest; // pub mod create_annotation_queue_request; // pub use self::create_annotation_queue_request::CreateAnnotationQueueRequest; // pub mod create_chat_prompt_request; // pub use self::create_chat_prompt_request::CreateChatPromptRequest; // pub mod create_comment_request; // pub use self::create_comment_request::CreateCommentRequest; // pub mod create_comment_response; // pub use self::create_comment_response::CreateCommentResponse; // pub mod create_dataset_item_request; // pub use self::create_dataset_item_request::CreateDatasetItemRequest; // pub mod create_dataset_request; // pub use self::create_dataset_request::CreateDatasetRequest; // pub mod create_dataset_run_item_request; // pub use 
self::create_dataset_run_item_request::CreateDatasetRunItemRequest; pub mod create_event_body; pub use self::create_event_body::CreateEventBody; // pub mod create_event_event; // pub use self::create_event_event::CreateEventEvent; pub mod create_generation_body; pub use self::create_generation_body::CreateGenerationBody; // pub mod create_generation_event; // pub use self::create_generation_event::CreateGenerationEvent; // pub mod create_model_request; // pub use self::create_model_request::CreateModelRequest; // pub mod create_observation_event; // pub use self::create_observation_event::CreateObservationEvent; // pub mod create_prompt_request; // pub use self::create_prompt_request::CreatePromptRequest; // pub mod create_prompt_request_one_of; // pub use self::create_prompt_request_one_of::CreatePromptRequestOneOf; // pub mod create_prompt_request_one_of_1; // pub use self::create_prompt_request_one_of_1::CreatePromptRequestOneOf1; // pub mod create_score_config_request; // pub use self::create_score_config_request::CreateScoreConfigRequest; // pub mod create_score_request; // pub use self::create_score_request::CreateScoreRequest; // pub mod create_score_response; // pub use self::create_score_response::CreateScoreResponse; pub mod create_score_value; pub use self::create_score_value::CreateScoreValue; pub mod create_span_body; pub use self::create_span_body::CreateSpanBody; // pub mod create_span_event; // pub use self::create_span_event::CreateSpanEvent; // pub mod create_text_prompt_request; // pub use self::create_text_prompt_request::CreateTextPromptRequest; // pub mod dataset; // pub use self::dataset::Dataset; // pub mod dataset_item; // pub use self::dataset_item::DatasetItem; // pub mod dataset_run; // pub use self::dataset_run::DatasetRun; // pub mod dataset_run_item; // pub use self::dataset_run_item::DatasetRunItem; // pub mod dataset_run_with_items; // pub use self::dataset_run_with_items::DatasetRunWithItems; // pub mod dataset_status; // pub use 
self::dataset_status::DatasetStatus; // pub mod delete_annotation_queue_assignment_response; // pub use self::delete_annotation_queue_assignment_response::DeleteAnnotationQueueAssignmentResponse; // pub mod delete_annotation_queue_item_response; // pub use self::delete_annotation_queue_item_response::DeleteAnnotationQueueItemResponse; // pub mod delete_dataset_item_response; // pub use self::delete_dataset_item_response::DeleteDatasetItemResponse; // pub mod delete_dataset_run_response; // pub use self::delete_dataset_run_response::DeleteDatasetRunResponse; // pub mod delete_membership_request; // pub use self::delete_membership_request::DeleteMembershipRequest; // pub mod delete_trace_response; // pub use self::delete_trace_response::DeleteTraceResponse; // pub mod filter_config; // pub use self::filter_config::FilterConfig; // pub mod get_comments_response; // pub use self::get_comments_response::GetCommentsResponse; // pub mod get_media_response; // pub use self::get_media_response::GetMediaResponse; // pub mod get_media_upload_url_request; // pub use self::get_media_upload_url_request::GetMediaUploadUrlRequest; // pub mod get_media_upload_url_response; // pub use self::get_media_upload_url_response::GetMediaUploadUrlResponse; // pub mod get_scores_response; // pub use self::get_scores_response::GetScoresResponse; // pub mod get_scores_response_data; // pub use self::get_scores_response_data::GetScoresResponseData; // pub mod get_scores_response_data_boolean; // pub use self::get_scores_response_data_boolean::GetScoresResponseDataBoolean; // pub mod get_scores_response_data_categorical; // pub use self::get_scores_response_data_categorical::GetScoresResponseDataCategorical; // pub mod get_scores_response_data_numeric; // pub use self::get_scores_response_data_numeric::GetScoresResponseDataNumeric; // pub mod get_scores_response_data_one_of; // pub use self::get_scores_response_data_one_of::GetScoresResponseDataOneOf; // pub mod get_scores_response_data_one_of_1; 
// pub use self::get_scores_response_data_one_of_1::GetScoresResponseDataOneOf1; // pub mod get_scores_response_data_one_of_2; // pub use self::get_scores_response_data_one_of_2::GetScoresResponseDataOneOf2; // pub mod get_scores_response_trace_data; // pub use self::get_scores_response_trace_data::GetScoresResponseTraceData; // pub mod health_response; // pub use self::health_response::HealthResponse; pub mod ingestion_batch_request; pub use self::ingestion_batch_request::IngestionBatchRequest; pub mod ingestion_error; pub use self::ingestion_error::IngestionError; pub mod ingestion_event; pub use self::ingestion_event::IngestionEvent; pub mod ingestion_event_one_of; pub use self::ingestion_event_one_of::TraceCreate; pub mod ingestion_event_one_of_1; pub use self::ingestion_event_one_of_1::ScoreCreate; pub mod ingestion_event_one_of_2; pub use self::ingestion_event_one_of_2::SpanCreate; pub mod ingestion_event_one_of_3; pub use self::ingestion_event_one_of_3::SpanUpdate; pub mod ingestion_event_one_of_4; pub use self::ingestion_event_one_of_4::GenerationCreate; pub mod ingestion_event_one_of_5; pub use self::ingestion_event_one_of_5::GenerationUpdate; pub mod ingestion_event_one_of_6; pub use self::ingestion_event_one_of_6::EventCreate; pub mod ingestion_event_one_of_7; pub use self::ingestion_event_one_of_7::SdkLog; pub mod ingestion_event_one_of_8; pub use self::ingestion_event_one_of_8::ObservationCreate; pub mod ingestion_event_one_of_9; pub use self::ingestion_event_one_of_9::ObservationUpdate; pub mod ingestion_response; pub use self::ingestion_response::IngestionResponse; pub mod ingestion_success; pub use self::ingestion_success::IngestionSuccess; pub mod ingestion_usage; pub use self::ingestion_usage::IngestionUsage; // pub mod llm_adapter; // pub use self::llm_adapter::LlmAdapter; // pub mod llm_connection; // pub use self::llm_connection::LlmConnection; pub mod map_value; pub use self::map_value::MapValue; // pub mod media_content_type; // pub use 
self::media_content_type::MediaContentType; // pub mod membership_deletion_response; // pub use self::membership_deletion_response::MembershipDeletionResponse; // pub mod membership_request; // pub use self::membership_request::MembershipRequest; // pub mod membership_response; // pub use self::membership_response::MembershipResponse; // pub mod membership_role; // pub use self::membership_role::MembershipRole; // pub mod memberships_response; // pub use self::memberships_response::MembershipsResponse; // pub mod metrics_response; // pub use self::metrics_response::MetricsResponse; // pub mod model; // pub use self::model::Model; // pub mod model_price; // pub use self::model_price::ModelPrice; pub mod model_usage_unit; pub use self::model_usage_unit::ModelUsageUnit; // pub mod numeric_score; // pub use self::numeric_score::NumericScore; // pub mod numeric_score_v1; // pub use self::numeric_score_v1::NumericScoreV1; // pub mod observation; // pub use self::observation::Observation; pub mod observation_body; pub use self::observation_body::ObservationBody; pub mod observation_level; pub use self::observation_level::ObservationLevel; pub mod observation_type; pub use self::observation_type::ObservationType; // pub mod observations; // pub use self::observations::Observations; // pub mod observations_view; // pub use self::observations_view::ObservationsView; // pub mod observations_views; // pub use self::observations_views::ObservationsViews; pub mod open_ai_completion_usage_schema; pub use self::open_ai_completion_usage_schema::OpenAiCompletionUsageSchema; pub mod open_ai_response_usage_schema; pub use self::open_ai_response_usage_schema::OpenAiResponseUsageSchema; pub mod open_ai_usage; pub use self::open_ai_usage::OpenAiUsage; pub mod optional_observation_body; // pub mod organization_project; // pub use self::organization_project::OrganizationProject; // pub mod organization_projects_response; // pub use 
self::organization_projects_response::OrganizationProjectsResponse; // pub mod paginated_annotation_queue_items; // pub use self::paginated_annotation_queue_items::PaginatedAnnotationQueueItems; // pub mod paginated_annotation_queues; // pub use self::paginated_annotation_queues::PaginatedAnnotationQueues; // pub mod paginated_dataset_items; // pub use self::paginated_dataset_items::PaginatedDatasetItems; // pub mod paginated_dataset_run_items; // pub use self::paginated_dataset_run_items::PaginatedDatasetRunItems; // pub mod paginated_dataset_runs; // pub use self::paginated_dataset_runs::PaginatedDatasetRuns; // pub mod paginated_datasets; // pub use self::paginated_datasets::PaginatedDatasets; // pub mod paginated_llm_connections; // pub use self::paginated_llm_connections::PaginatedLlmConnections; // pub mod paginated_models; // pub use self::paginated_models::PaginatedModels; // pub mod paginated_sessions; // pub use self::paginated_sessions::PaginatedSessions; // pub mod patch_media_body; // pub use self::patch_media_body::PatchMediaBody; // pub mod placeholder_message; // pub use self::placeholder_message::PlaceholderMessage; // pub mod project; // pub use self::project::Project; // pub mod project_deletion_response; // pub use self::project_deletion_response::ProjectDeletionResponse; // pub mod projects; // pub use self::projects::Projects; // pub mod projects_create_api_key_request; // pub use self::projects_create_api_key_request::ProjectsCreateApiKeyRequest; // pub mod projects_create_request; // pub use self::projects_create_request::ProjectsCreateRequest; // pub mod prompt; // pub use self::prompt::Prompt; // pub mod prompt_meta; // pub use self::prompt_meta::PromptMeta; // pub mod prompt_meta_list_response; // pub use self::prompt_meta_list_response::PromptMetaListResponse; // pub mod prompt_one_of; // pub use self::prompt_one_of::PromptOneOf; // pub mod prompt_one_of_1; // pub use self::prompt_one_of_1::PromptOneOf1; // pub mod 
prompt_version_update_request; // pub use self::prompt_version_update_request::PromptVersionUpdateRequest; // pub mod resource_meta; // pub use self::resource_meta::ResourceMeta; // pub mod resource_type; // pub use self::resource_type::ResourceType; // pub mod resource_types_response; // pub use self::resource_types_response::ResourceTypesResponse; // pub mod schema_extension; // pub use self::schema_extension::SchemaExtension; // pub mod schema_resource; // pub use self::schema_resource::SchemaResource; // pub mod schemas_response; // pub use self::schemas_response::SchemasResponse; // pub mod scim_create_user_request; // pub use self::scim_create_user_request::ScimCreateUserRequest; // pub mod scim_email; // pub use self::scim_email::ScimEmail; // pub mod scim_feature_support; // pub use self::scim_feature_support::ScimFeatureSupport; // pub mod scim_name; // pub use self::scim_name::ScimName; // pub mod scim_user; // pub use self::scim_user::ScimUser; // pub mod scim_users_list_response; // pub use self::scim_users_list_response::ScimUsersListResponse; // pub mod score; // pub use self::score::Score; pub mod score_body; pub use self::score_body::ScoreBody; // pub mod score_config; // pub use self::score_config::ScoreConfig; // pub mod score_configs; // pub use self::score_configs::ScoreConfigs; pub mod score_data_type; pub use self::score_data_type::ScoreDataType; // pub mod score_event; // pub use self::score_event::ScoreEvent; // pub mod score_one_of; // pub use self::score_one_of::ScoreOneOf; // pub mod score_one_of_1; // pub use self::score_one_of_1::ScoreOneOf1; // pub mod score_one_of_2; // pub use self::score_one_of_2::ScoreOneOf2; // pub mod score_source; // pub use self::score_source::ScoreSource; // pub mod score_v1; // pub use self::score_v1::ScoreV1; // pub mod score_v1_one_of; // pub use self::score_v1_one_of::ScoreV1OneOf; // pub mod score_v1_one_of_1; // pub use self::score_v1_one_of_1::ScoreV1OneOf1; // pub mod score_v1_one_of_2; // pub use 
self::score_v1_one_of_2::ScoreV1OneOf2; pub mod sdk_log_body; pub use self::sdk_log_body::SdkLogBody; // pub mod sdk_log_event; // pub use self::sdk_log_event::SdkLogEvent; // pub mod service_provider_config; // pub use self::service_provider_config::ServiceProviderConfig; // pub mod session; // pub use self::session::Session; // pub mod session_with_traces; // pub use self::session_with_traces::SessionWithTraces; // pub mod sort; // pub use self::sort::Sort; // pub mod text_prompt; // pub use self::text_prompt::TextPrompt; // pub mod trace; // pub use self::trace::Trace; pub mod trace_body; pub use self::trace_body::TraceBody; // pub mod trace_delete_multiple_request; // pub use self::trace_delete_multiple_request::TraceDeleteMultipleRequest; // pub mod trace_event; // pub use self::trace_event::TraceEvent; // pub mod trace_with_details; // pub use self::trace_with_details::TraceWithDetails; // pub mod trace_with_full_details; // pub use self::trace_with_full_details::TraceWithFullDetails; // pub mod traces; // pub use self::traces::Traces; // pub mod update_annotation_queue_item_request; // pub use self::update_annotation_queue_item_request::UpdateAnnotationQueueItemRequest; // pub mod update_event_body; // pub use self::update_event_body::UpdateEventBody; pub mod update_generation_body; pub use self::update_generation_body::UpdateGenerationBody; // pub mod update_generation_event; // pub use self::update_generation_event::UpdateGenerationEvent; // pub mod update_observation_event; // pub use self::update_observation_event::UpdateObservationEvent; pub mod update_span_body; pub use self::update_span_body::UpdateSpanBody; // pub mod update_span_event; // pub use self::update_span_event::UpdateSpanEvent; // pub mod upsert_llm_connection_request; // pub use self::upsert_llm_connection_request::UpsertLlmConnectionRequest; pub mod usage; pub use self::usage::Usage; pub mod usage_details; pub use self::usage_details::UsageDetails; // pub mod user_meta; // pub use 
self::user_meta::UserMeta; // pub mod utils_meta_response; // pub use self::utils_meta_response::UtilsMetaResponse; ================================================ FILE: swiftide-langfuse/src/models/model_usage_unit.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use serde::{Deserialize, Serialize}; /// `ModelUsageUnit` : Unit of usage in Langfuse /// Unit of usage in Langfuse #[derive( Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default, )] pub enum ModelUsageUnit { #[serde(rename = "CHARACTERS")] #[default] Characters, #[serde(rename = "TOKENS")] Tokens, #[serde(rename = "MILLISECONDS")] Milliseconds, #[serde(rename = "SECONDS")] Seconds, #[serde(rename = "IMAGES")] Images, #[serde(rename = "REQUESTS")] Requests, } impl std::fmt::Display for ModelUsageUnit { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::Characters => write!(f, "CHARACTERS"), Self::Tokens => write!(f, "TOKENS"), Self::Milliseconds => write!(f, "MILLISECONDS"), Self::Seconds => write!(f, "SECONDS"), Self::Images => write!(f, "IMAGES"), Self::Requests => write!(f, "REQUESTS"), } } } ================================================ FILE: swiftide-langfuse/src/models/observation_body.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse 
Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech
// NOTE(review): machine-generated model — regenerate from the OpenAPI spec rather
// than hand-editing.

use crate::models;
use serde::{Deserialize, Serialize};

/// Payload shared by observation create/update ingestion events.
/// All fields except `type` are optional; the `double_option` fields
/// distinguish "absent" (outer `None`) from explicit JSON null (inner `None`),
/// which matters for updates (null clears a value, absent leaves it alone) —
/// do not simplify those attributes.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct ObservationBody {
    #[serde(rename = "id", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub id: Option<Option<String>>,
    #[serde(rename = "traceId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub trace_id: Option<Option<String>>,
    // The only required field: what kind of observation this is.
    #[serde(rename = "type")]
    pub r#type: models::ObservationType,
    #[serde(rename = "name", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub name: Option<Option<String>>,
    #[serde(rename = "startTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub start_time: Option<Option<String>>,
    #[serde(rename = "endTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub end_time: Option<Option<String>>,
    #[serde(rename = "completionStartTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub completion_start_time: Option<Option<String>>,
    #[serde(rename = "model", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub model: Option<Option<String>>,
    #[serde(rename = "modelParameters", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub model_parameters: Option<Option<std::collections::HashMap<String, models::MapValue>>>,
    #[serde(rename = "input", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub input: Option<Option<serde_json::Value>>,
    #[serde(rename = "version", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub version: Option<Option<String>>,
    #[serde(rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Option<serde_json::Value>>,
    #[serde(rename = "output", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub output: Option<Option<serde_json::Value>>,
    // Plain Option (no double_option): null and absent are treated the same here.
    #[serde(rename = "usage", skip_serializing_if = "Option::is_none")]
    pub usage: Option<Box<models::Usage>>,
    #[serde(rename = "level", skip_serializing_if = "Option::is_none")]
    pub level: Option<models::ObservationLevel>,
    #[serde(rename = "statusMessage", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub status_message: Option<Option<String>>,
    #[serde(rename = "parentObservationId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub parent_observation_id: Option<Option<String>>,
    #[serde(rename = "environment", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none")]
    pub environment: Option<Option<String>>,
}

impl ObservationBody {
    /// Constructs an `ObservationBody` of the given type with every optional
    /// field unset (i.e. omitted from serialization).
    pub fn new(r#type: models::ObservationType) -> ObservationBody {
        ObservationBody {
            id: None,
            trace_id: None,
            r#type,
            name: None,
            start_time: None,
            end_time: None,
            completion_start_time: None,
            model: None,
            model_parameters: None,
            input: None,
            version: None,
            metadata: None,
            output: None,
            usage: None,
            level: None,
            status_message: None,
            parent_observation_id: None,
            environment: None,
        }
    }
}
================================================ FILE: swiftide-langfuse/src/models/observation_level.rs ================================================
// langfuse // // ## Authentication Authenticate with the API using [Basic
Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use serde::{Deserialize, Serialize}; #[derive( Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default, )] pub enum ObservationLevel { #[serde(rename = "DEBUG")] #[default] Debug, #[serde(rename = "DEFAULT")] Default, #[serde(rename = "WARNING")] Warning, #[serde(rename = "ERROR")] Error, } impl std::fmt::Display for ObservationLevel { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::Debug => write!(f, "DEBUG"), Self::Default => write!(f, "DEFAULT"), Self::Warning => write!(f, "WARNING"), Self::Error => write!(f, "ERROR"), } } } ================================================ FILE: swiftide-langfuse/src/models/observation_type.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use serde::{Deserialize, Serialize}; #[derive( Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default, )] pub enum ObservationType { #[serde(rename = "SPAN")] #[default] Span, #[serde(rename = "GENERATION")] Generation, #[serde(rename = "EVENT")] Event, #[serde(rename = "AGENT")] 
Agent, #[serde(rename = "TOOL")] Tool, #[serde(rename = "CHAIN")] Chain, #[serde(rename = "RETRIEVER")] Retriever, #[serde(rename = "EVALUATOR")] Evaluator, #[serde(rename = "EMBEDDING")] Embedding, #[serde(rename = "GUARDRAIL")] Guardrail, } impl std::fmt::Display for ObservationType { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::Span => write!(f, "SPAN"), Self::Generation => write!(f, "GENERATION"), Self::Event => write!(f, "EVENT"), Self::Agent => write!(f, "AGENT"), Self::Tool => write!(f, "TOOL"), Self::Chain => write!(f, "CHAIN"), Self::Retriever => write!(f, "RETRIEVER"), Self::Evaluator => write!(f, "EVALUATOR"), Self::Embedding => write!(f, "EMBEDDING"), Self::Guardrail => write!(f, "GUARDRAIL"), } } } ================================================ FILE: swiftide-langfuse/src/models/open_ai_completion_usage_schema.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use serde::{Deserialize, Serialize}; /// `OpenAiCompletionUsageSchema` : `OpenAI` Usage schema from (Chat-)Completion APIs #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] pub struct OpenAiCompletionUsageSchema { #[serde(rename = "prompt_tokens")] pub prompt_tokens: i32, #[serde(rename = "completion_tokens")] pub completion_tokens: i32, #[serde(rename = "total_tokens")] pub total_tokens: i32, #[serde( rename = "prompt_tokens_details", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub 
prompt_tokens_details: Option<Option<std::collections::HashMap<String, i32>>>,
    // double_option: outer None = field absent, inner None = explicit JSON null.
    #[serde(
        rename = "completion_tokens_details",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub completion_tokens_details: Option<Option<std::collections::HashMap<String, i32>>>,
}

impl OpenAiCompletionUsageSchema {
    /// `OpenAI` Usage schema from (Chat-)Completion APIs
    // Detail maps start as None; set them after construction if needed.
    pub fn new(
        prompt_tokens: i32,
        completion_tokens: i32,
        total_tokens: i32,
    ) -> OpenAiCompletionUsageSchema {
        OpenAiCompletionUsageSchema {
            prompt_tokens,
            completion_tokens,
            total_tokens,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        }
    }
}

================================================ FILE: swiftide-langfuse/src/models/open_ai_response_usage_schema.rs ================================================

// langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech

use serde::{Deserialize, Serialize};

/// `OpenAiResponseUsageSchema` : `OpenAI` Usage schema from Response API
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct OpenAiResponseUsageSchema {
    #[serde(rename = "input_tokens")]
    pub input_tokens: i32,
    #[serde(rename = "output_tokens")]
    pub output_tokens: i32,
    #[serde(rename = "total_tokens")]
    pub total_tokens: i32,
    // double_option: outer None = field absent, inner None = explicit JSON null.
    #[serde(
        rename = "input_tokens_details",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub input_tokens_details: Option<Option<std::collections::HashMap<String, i32>>>,
    #[serde(
        rename = "output_tokens_details",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub output_tokens_details: Option<Option<std::collections::HashMap<String, i32>>>,
}

impl OpenAiResponseUsageSchema {
    /// `OpenAI` Usage schema from Response API
    // Detail maps start as None; set them after construction if needed.
    pub fn new(
        input_tokens: i32,
        output_tokens: i32,
        total_tokens: i32,
    ) -> OpenAiResponseUsageSchema {
        OpenAiResponseUsageSchema {
            input_tokens,
            output_tokens,
            total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }
    }
}

================================================ FILE: swiftide-langfuse/src/models/open_ai_usage.rs ================================================

// langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech

use serde::{Deserialize, Serialize};

/// `OpenAiUsage` : Usage interface of `OpenAI` for improved compatibility.
// All counters are double-optional: outer None = absent, inner None = JSON null.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct OpenAiUsage {
    #[serde(
        rename = "promptTokens",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub prompt_tokens: Option<Option<i32>>,
    #[serde(
        rename = "completionTokens",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub completion_tokens: Option<Option<i32>>,
    #[serde(
        rename = "totalTokens",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub total_tokens: Option<Option<i32>>,
}

impl OpenAiUsage {
    /// Usage interface of `OpenAI` for improved compatibility.
pub fn new() -> OpenAiUsage { OpenAiUsage { prompt_tokens: None, completion_tokens: None, total_tokens: None, } } } ================================================ FILE: swiftide-langfuse/src/models/optional_observation_body.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use crate::models; use serde::{Deserialize, Serialize}; #[allow(dead_code)] #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] pub struct OptionalObservationBody { #[serde( rename = "traceId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub trace_id: Option<Option<String>>, #[serde( rename = "name", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub name: Option<Option<String>>, #[serde( rename = "startTime", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub start_time: Option<Option<String>>, #[serde( rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub metadata: Option<Option<serde_json::Value>>, #[serde( rename = "input", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub input: Option<Option<serde_json::Value>>, #[serde( rename = "output", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub output: Option<Option<serde_json::Value>>, #[serde(rename = "level", 
skip_serializing_if = "Option::is_none")] pub level: Option<models::ObservationLevel>, #[serde( rename = "statusMessage", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub status_message: Option<Option<String>>, #[serde( rename = "parentObservationId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub parent_observation_id: Option<Option<String>>, #[serde( rename = "version", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub version: Option<Option<String>>, #[serde( rename = "environment", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub environment: Option<Option<String>>, } ================================================ FILE: swiftide-langfuse/src/models/score_body.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use crate::models; use serde::{Deserialize, Serialize}; #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] pub struct ScoreBody { #[serde( rename = "id", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub id: Option<Option<String>>, #[serde( rename = "traceId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub trace_id: Option<Option<String>>, #[serde( rename = "sessionId", default, with = "::serde_with::rust::double_option", skip_serializing_if = 
"Option::is_none" )] pub session_id: Option<Option<String>>, #[serde( rename = "observationId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub observation_id: Option<Option<String>>, #[serde( rename = "datasetRunId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub dataset_run_id: Option<Option<String>>, #[serde(rename = "name")] pub name: String, #[serde( rename = "environment", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub environment: Option<Option<String>>, #[serde(rename = "value")] pub value: Box<models::CreateScoreValue>, #[serde( rename = "comment", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub comment: Option<Option<String>>, #[serde( rename = "metadata", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub metadata: Option<Option<serde_json::Value>>, #[serde(rename = "dataType", skip_serializing_if = "Option::is_none")] pub data_type: Option<models::ScoreDataType>, /// Reference a score config on a score. When set, the score name must equal the config name /// and scores must comply with the config's range and data type. For categorical scores, the /// value must map to a config category. 
Numeric scores might be constrained by the score /// config's max and min values #[serde( rename = "configId", default, with = "::serde_with::rust::double_option", skip_serializing_if = "Option::is_none" )] pub config_id: Option<Option<String>>, } impl ScoreBody { pub fn new(name: String, value: models::CreateScoreValue) -> ScoreBody { ScoreBody { id: None, trace_id: None, session_id: None, observation_id: None, dataset_run_id: None, name, environment: None, value: Box::new(value), comment: None, metadata: None, data_type: None, config_id: None, } } } ================================================ FILE: swiftide-langfuse/src/models/score_data_type.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech use serde::{Deserialize, Serialize}; #[derive( Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize, Default, )] pub enum ScoreDataType { #[serde(rename = "NUMERIC")] #[default] Numeric, #[serde(rename = "BOOLEAN")] Boolean, #[serde(rename = "CATEGORICAL")] Categorical, } impl std::fmt::Display for ScoreDataType { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::Numeric => write!(f, "NUMERIC"), Self::Boolean => write!(f, "BOOLEAN"), Self::Categorical => write!(f, "CATEGORICAL"), } } } ================================================ FILE: swiftide-langfuse/src/models/sdk_log_body.rs ================================================ // langfuse // // ## Authentication Authenticate with the API using [Basic 
Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech

use serde::{Deserialize, Serialize};

/// Body carrying a single `log` JSON payload.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct SdkLogBody {
    // NOTE(review): with `deserialize_with = "Option::deserialize"` and no
    // `default`, the "log" key appears to be required in incoming JSON (null
    // allowed) — confirm that is the intent.
    #[serde(rename = "log", deserialize_with = "Option::deserialize")]
    pub log: Option<serde_json::Value>,
}

impl SdkLogBody {
    pub fn new(log: Option<serde_json::Value>) -> SdkLogBody {
        SdkLogBody { log }
    }
}

================================================ FILE: swiftide-langfuse/src/models/trace_body.rs ================================================

// langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech

use serde::{Deserialize, Serialize};

/// Body describing a trace; every field is optional.
// double_option fields distinguish "absent" from explicit JSON null.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct TraceBody {
    #[serde(
        rename = "id",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub id: Option<Option<String>>,
    #[serde(
        rename = "timestamp",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub timestamp: Option<Option<String>>,
    #[serde(
        rename = "name",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub name: Option<Option<String>>,
    #[serde(
        rename = "userId",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub user_id: Option<Option<String>>,
    #[serde(
        rename = "input",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub input: Option<Option<serde_json::Value>>,
    #[serde(
        rename = "output",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub output: Option<Option<serde_json::Value>>,
    #[serde(
        rename = "sessionId",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub session_id: Option<Option<String>>,
    #[serde(
        rename = "release",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub release: Option<Option<String>>,
    #[serde(
        rename = "version",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub version: Option<Option<String>>,
    #[serde(
        rename = "metadata",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub metadata: Option<Option<serde_json::Value>>,
    #[serde(
        rename = "tags",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub tags: Option<Option<Vec<String>>>,
    #[serde(
        rename = "environment",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub environment: Option<Option<String>>,
    /// Make trace publicly accessible via url
    #[serde(
        rename = "public",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub public: Option<Option<bool>>,
}

impl TraceBody {
    /// Builds an empty trace body; equivalent to `TraceBody::default()`.
    pub fn new() -> TraceBody {
        TraceBody {
            id: None,
            timestamp: None,
            name: None,
            user_id: None,
            input: None,
            output: None,
            session_id: None,
            release: None,
            version: None,
            metadata: None,
            tags: None,
            environment: None,
            public: None,
        }
    }
}
================================================ FILE: swiftide-langfuse/src/models/update_generation_body.rs ================================================

// langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Body for updating a generation; only `id` is required.
// double_option fields distinguish "absent" from explicit JSON null.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct UpdateGenerationBody {
    #[serde(
        rename = "completionStartTime",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub completion_start_time: Option<Option<String>>,
    #[serde(
        rename = "model",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub model: Option<Option<String>>,
    #[serde(
        rename = "modelParameters",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub model_parameters: Option<Option<std::collections::HashMap<String, models::MapValue>>>,
    #[serde(rename = "usage", skip_serializing_if = "Option::is_none")]
    pub usage: Option<Box<models::IngestionUsage>>,
    #[serde(
        rename = "promptName",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub prompt_name: Option<Option<String>>,
    #[serde(rename = "usageDetails", skip_serializing_if = "Option::is_none")]
    pub usage_details: Option<Box<models::UsageDetails>>,
    #[serde(
        rename = "costDetails",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub cost_details: Option<Option<std::collections::HashMap<String,
f64>>>,
    #[serde(
        rename = "promptVersion",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub prompt_version: Option<Option<i32>>,
    #[serde(
        rename = "endTime",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub end_time: Option<Option<String>>,
    // Required field: always serialized.
    #[serde(rename = "id")]
    pub id: String,
    #[serde(
        rename = "traceId",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub trace_id: Option<Option<String>>,
    #[serde(
        rename = "name",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub name: Option<Option<String>>,
    #[serde(
        rename = "startTime",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub start_time: Option<Option<String>>,
    #[serde(
        rename = "metadata",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub metadata: Option<Option<serde_json::Value>>,
    #[serde(
        rename = "input",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub input: Option<Option<serde_json::Value>>,
    #[serde(
        rename = "output",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub output: Option<Option<serde_json::Value>>,
    #[serde(rename = "level", skip_serializing_if = "Option::is_none")]
    pub level: Option<models::ObservationLevel>,
    #[serde(
        rename = "statusMessage",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub status_message: Option<Option<String>>,
    #[serde(
        rename = "parentObservationId",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub parent_observation_id: Option<Option<String>>,
    #[serde(
        rename = "version",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub version:
Option<Option<String>>,
    #[serde(
        rename = "environment",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub environment: Option<Option<String>>,
}

impl UpdateGenerationBody {
    /// Builds an update body targeting generation `id`; every other field
    /// starts as `None`.
    pub fn new(id: String) -> UpdateGenerationBody {
        UpdateGenerationBody {
            completion_start_time: None,
            model: None,
            model_parameters: None,
            usage: None,
            prompt_name: None,
            usage_details: None,
            cost_details: None,
            prompt_version: None,
            end_time: None,
            id,
            trace_id: None,
            name: None,
            start_time: None,
            metadata: None,
            input: None,
            output: None,
            level: None,
            status_message: None,
            parent_observation_id: None,
            version: None,
            environment: None,
        }
    }
}

================================================ FILE: swiftide-langfuse/src/models/update_span_body.rs ================================================

// langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Body for updating a span; only `id` is required.
// double_option fields distinguish "absent" from explicit JSON null.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct UpdateSpanBody {
    #[serde(
        rename = "endTime",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub end_time: Option<Option<String>>,
    // Required field: always serialized.
    #[serde(rename = "id")]
    pub id: String,
    #[serde(
        rename = "traceId",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub trace_id: Option<Option<String>>,
    #[serde(
        rename = "name",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub name: Option<Option<String>>,
    #[serde(
        rename = "startTime",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub start_time: Option<Option<String>>,
    #[serde(
        rename = "metadata",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub metadata: Option<Option<serde_json::Value>>,
    #[serde(
        rename = "input",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub input: Option<Option<serde_json::Value>>,
    #[serde(
        rename = "output",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub output: Option<Option<serde_json::Value>>,
    #[serde(rename = "level", skip_serializing_if = "Option::is_none")]
    pub level: Option<models::ObservationLevel>,
    #[serde(
        rename = "statusMessage",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub status_message: Option<Option<String>>,
    #[serde(
        rename = "parentObservationId",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub parent_observation_id: Option<Option<String>>,
    #[serde(
        rename = "version",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub version: Option<Option<String>>,
    #[serde(
        rename = "environment",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub environment: Option<Option<String>>,
}

impl UpdateSpanBody {
    /// Builds an update body targeting span `id`; every other field starts
    /// as `None`.
    pub fn new(id: String) -> UpdateSpanBody {
        UpdateSpanBody {
            end_time: None,
            id,
            trace_id: None,
            name: None,
            start_time: None,
            metadata: None,
            input: None,
            output: None,
            level: None,
            status_message: None,
            parent_observation_id: None,
            version: None,
            environment: None,
        }
    }
}

================================================ FILE: swiftide-langfuse/src/models/usage.rs ================================================

// langfuse // // ## Authentication Authenticate with the API using
[Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Usage : (Deprecated. Use usageDetails and costDetails instead.) Standard interface for usage and
/// cost
// double_option fields distinguish "absent" from explicit JSON null.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct Usage {
    /// Number of input units (e.g. tokens)
    #[serde(
        rename = "input",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub input: Option<Option<i32>>,
    /// Number of output units (e.g. tokens)
    #[serde(
        rename = "output",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub output: Option<Option<i32>>,
    /// Defaults to input+output if not set
    #[serde(
        rename = "total",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub total: Option<Option<i32>>,
    #[serde(rename = "unit", skip_serializing_if = "Option::is_none")]
    pub unit: Option<models::ModelUsageUnit>,
    /// USD input cost
    #[serde(
        rename = "inputCost",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub input_cost: Option<Option<f64>>,
    /// USD output cost
    #[serde(
        rename = "outputCost",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub output_cost: Option<Option<f64>>,
    /// USD total cost, defaults to input+output
    #[serde(
        rename = "totalCost",
        default,
        with = "::serde_with::rust::double_option",
        skip_serializing_if = "Option::is_none"
    )]
    pub total_cost: Option<Option<f64>>,
}

impl Usage {
    /// (Deprecated.
Use usageDetails and costDetails instead.) Standard interface for usage and
    /// cost
    pub fn new() -> Usage {
        // Equivalent to Usage::default(): every field starts as None.
        Usage {
            input: None,
            output: None,
            total: None,
            unit: None,
            input_cost: None,
            output_cost: None,
            total_cost: None,
        }
    }
}

impl From<swiftide_core::chat_completion::Usage> for Usage {
    /// Converts swiftide token usage into the Langfuse usage model;
    /// costs are left unset and the unit is fixed to `Tokens`.
    fn from(value: swiftide_core::chat_completion::Usage) -> Self {
        Usage {
            // NOTE(review): `as i32` silently wraps token counts that exceed
            // i32::MAX — confirm this is acceptable for the ingestion API.
            input: Some(Some(value.prompt_tokens as i32)),
            output: Some(Some(value.completion_tokens as i32)),
            total: Some(Some(value.total_tokens as i32)),
            unit: Some(models::ModelUsageUnit::Tokens),
            input_cost: None,
            output_cost: None,
            total_cost: None,
        }
    }
}

================================================ FILE: swiftide-langfuse/src/models/usage_details.rs ================================================

// langfuse // // ## Authentication Authenticate with the API using [Basic Auth](https://en.wikipedia.org/wiki/Basic_access_authentication), get API keys in the project settings: - username: Langfuse Public Key - password: Langfuse Secret Key ## Exports - OpenAPI spec: https://cloud.langfuse.com/generated/api/openapi.yml - Postman collection: https://cloud.langfuse.com/generated/postman/collection.json // // The version of the OpenAPI document: // // Generated by: https://openapi-generator.tech

use crate::models;
use serde::{Deserialize, Serialize};

/// Untagged union of usage-detail shapes; serde tries the variants in the
/// order they are declared here.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum UsageDetails {
    Object(std::collections::HashMap<String, i32>),
    OpenAiCompletionUsageSchema(Box<models::OpenAiCompletionUsageSchema>),
    OpenAiResponseUsageSchema(Box<models::OpenAiResponseUsageSchema>),
}

impl Default for UsageDetails {
    // Defaults to an empty generic map.
    fn default() -> Self {
        Self::Object(Default::default())
    }
}

================================================ FILE: swiftide-langfuse/src/tracing_layer.rs ================================================

use anyhow::Context as _;
use chrono::Utc;
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
use std::str::FromStr as _;
use
std::sync::Arc;
use std::{env, fmt};
use tokio::sync::Mutex;
use tracing::field::{Field, Visit};
use tracing::{Event, Id, Level, Metadata, Subscriber, span};
use tracing_subscriber::Layer;
use tracing_subscriber::layer::Context;
use tracing_subscriber::registry::LookupSpan;
use uuid::Uuid;

use crate::langfuse_batch_manager::{BatchManagerTrait, LangfuseBatchManager};
use crate::models::{
    IngestionEvent, ObservationBody, ObservationLevel, ObservationType, TraceBody,
};
use crate::{Configuration, DEFAULT_LANGFUSE_URL};

/// Data collected for a single `tracing` span before it is turned into a
/// Langfuse observation.
#[derive(Default, Debug, Clone)]
pub struct SpanData {
    pub observation_id: String, // Langfuse requires ids to be UUID v4 strings
    pub name: String,
    pub start_time: String,
    pub level: ObservationLevel,
    // Raw span fields; "langfuse."-prefixed keys carry Langfuse-specific
    // values (see `remaining_metadata`).
    pub metadata: serde_json::Map<String, Value>,
    pub parent_span_id: Option<u64>,
}

impl SpanData {
    /// Looks up `key` in the metadata map and deserializes the value into `T`.
    ///
    /// Returns `None` when the key is missing or the value does not parse as
    /// `T`; a failed parse is logged at WARN level instead of being
    /// propagated.
    pub fn get<T>(&self, key: &str) -> Option<T>
    where
        T: serde::de::DeserializeOwned,
    {
        if let Some(value) = self.metadata.get(key) {
            let parsed = serde_json::from_value(value.clone());
            if let Err(e) = &parsed {
                tracing::warn!(
                    error.msg = %e,
                    error.type = %std::any::type_name_of_val(e),
                    key = %key,
                    value = %value,
                    "[Langfuse] Failed to parse metadata field"
                );
            }
            return parsed.ok();
        }
        None
    }

    /// Returns metadata with all keys that do not start with "langfuse."
    #[must_use]
    pub fn remaining_metadata(&self) -> Option<serde_json::Map<String, Value>> {
        // Works on a clone so the stored metadata (including "langfuse."
        // keys) stays intact.
        let mut metadata = self.metadata.clone();
        metadata.retain(|k, _| !k.starts_with("langfuse."));
        if metadata.is_empty() { None } else { Some(metadata) }
    }
}

impl From<serde_json::Map<String, Value>> for SpanData {
    // All other fields take their Default values.
    fn from(metadata: serde_json::Map<String, Value>) -> Self {
        SpanData {
            metadata,
            ..Default::default()
        }
    }
}

/// Maps a `tracing` level to a Langfuse observation level.
/// INFO maps to `Default`; TRACE is folded into `Debug`.
pub fn map_level(level: &Level) -> ObservationLevel {
    use ObservationLevel::{Debug, Default, Error, Warning};
    match *level {
        Level::ERROR => Error,
        Level::WARN => Warning,
        Level::INFO => Default,
        Level::DEBUG => Debug,
        Level::TRACE => Debug,
    }
}

/// Bookkeeping for currently open spans, keyed by the `tracing` span id.
#[derive(Debug)]
pub struct SpanTracker {
    // span id -> (Langfuse observation id, observation type)
    active_spans: HashMap<u64, (String, ObservationType)>,
    current_trace_id: Option<String>,
}

impl Default for SpanTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl SpanTracker {
    /// Creates an empty tracker with no active spans and no current trace id.
    pub fn new() -> Self {
        Self {
            active_spans: HashMap::new(),
            current_trace_id: None,
        }
    }

    /// Registers an open span under its Langfuse observation id and type.
    pub fn add_span(&mut self, span_id: u64, observation_id: String, ty: ObservationType) {
        self.active_spans.insert(span_id, (observation_id, ty));
    }

    /// Looks up the observation id and type for an active span, if any.
    pub fn get_span(&self, span_id: u64) -> Option<&(String, ObservationType)> {
        self.active_spans.get(&span_id)
    }

    /// Removes and returns the entry for a span, if it was active.
    pub fn remove_span(&mut self, span_id: u64) -> Option<(String, ObservationType)> {
        self.active_spans.remove(&span_id)
    }
}

/// `tracing-subscriber` layer that forwards span data to Langfuse through a
/// batch manager.
#[derive(Clone)]
pub struct LangfuseLayer {
    pub batch_manager: Box<dyn BatchManagerTrait>,
    pub span_tracker: Arc<Mutex<SpanTracker>>,
}

/// Builds an observation-create ingestion event from collected span data.
///
/// "langfuse."-prefixed metadata values override the corresponding span
/// fields; all remaining metadata is passed through as-is.
fn observation_create_from(
    trace_id: &str,
    observation_id: &str,
    span_data: &mut SpanData,
    parent_observation_id: Option<String>,
) -> IngestionEvent {
    // Expect all langfuse values to be prefixed by "langfuse."
// Extract the fields from the metadata // Metadata is all values without a langfuse prefix let metadata = span_data.remaining_metadata().map(Into::into); let start_time = span_data .get("langfuse.start_time") .unwrap_or(span_data.start_time.clone()); let name = span_data.get("otel.name").unwrap_or(span_data.name.clone()); let swiftide_usage = span_data.get::<swiftide_core::chat_completion::Usage>("langfuse.usage"); IngestionEvent::new_observation_create(ObservationBody { id: Some(Some(observation_id.to_string())), trace_id: Some(Some(trace_id.to_string())), r#type: span_data .get("langfuse.type") .unwrap_or(ObservationType::Span), name: Some(Some(name)), start_time: Some(Some(start_time)), level: Some(span_data.level), parent_observation_id: Some(parent_observation_id), metadata: Some(metadata), model: Some(span_data.get("langfuse.model")), model_parameters: Some(span_data.get("langfuse.model_parameters")), input: Some(span_data.get("langfuse.input")), version: Some(span_data.get("langfuse.version")), output: Some(span_data.get("langfuse.output")), usage: swiftide_usage.map(|u| Box::new(u.into())), status_message: Some(span_data.get("langfuse.status_message")), environment: Some(span_data.get("langfuse.environment")), completion_start_time: None, end_time: None, }) } impl Default for LangfuseLayer { fn default() -> Self { let public_key = env::var("LANGFUSE_PUBLIC_KEY") .or_else(|_| env::var("LANGFUSE_INIT_PROJECT_PUBLIC_KEY")) .unwrap_or_default(); let secret_key = env::var("LANGFUSE_SECRET_KEY") .or_else(|_| env::var("LANGFUSE_INIT_PROJECT_SECRET_KEY")) .unwrap_or_default(); if public_key.is_empty() || secret_key.is_empty() { panic!( "Public key or secret key not set. Please set LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables." 
            );
        }

        let base_url =
            env::var("LANGFUSE_URL").unwrap_or_else(|_| DEFAULT_LANGFUSE_URL.to_string());

        let config = Configuration {
            base_path: base_url.clone(),
            user_agent: Some("swiftide".to_string()),
            client: Client::new(),
            basic_auth: Some((public_key.clone(), Some(secret_key.clone()))),
            ..Default::default()
        };

        // Spawn the background task that ships batched events to Langfuse.
        let batch_manager = LangfuseBatchManager::new(config);
        batch_manager.clone().spawn();

        LangfuseLayer {
            batch_manager: batch_manager.boxed(),
            span_tracker: Arc::new(Mutex::new(SpanTracker::new())),
        }
    }
}

impl LangfuseLayer {
    // Builds the layer from an existing configuration
    pub fn from_config(config: Configuration) -> Self {
        let batch_manager = LangfuseBatchManager::new(config);
        batch_manager.clone().spawn();
        let span_tracker = Arc::new(Mutex::new(SpanTracker::new()));
        Self {
            batch_manager: batch_manager.boxed(),
            span_tracker,
        }
    }

    // Start the layer with a batch manager
    //
    // Note that the batch manager _must_ be started before using this layer.
    pub fn from_batch_manager(batch_manager: &LangfuseBatchManager) -> Self {
        let span_tracker = Arc::new(Mutex::new(SpanTracker::new()));
        Self {
            batch_manager: batch_manager.boxed(),
            span_tracker,
        }
    }

    /// Flushes any pending events through the batch manager.
    ///
    /// # Errors
    ///
    /// Returns an error when the underlying batch manager fails to flush.
    pub async fn flush(&self) -> anyhow::Result<()> {
        self.batch_manager
            .flush()
            .await
            .context("Failed to flush")?;
        Ok(())
    }

    /// Registers a newly opened span and emits an observation-create event.
    pub async fn handle_span(&self, span_id: u64, mut span_data: SpanData) {
        let observation_id = span_data.observation_id.clone();
        let langfuse_ty = span_data
            .get("langfuse.type")
            .unwrap_or(ObservationType::Span);
        // Scope the lock so it is released before the awaits below.
        {
            let mut spans = self.span_tracker.lock().await;
            spans.add_span(span_id, observation_id.clone(), langfuse_ty);
        }

        // Get parent ID if it exists
        let parent_id = if let Some(parent_span_id) = span_data.parent_span_id {
            let spans = self.span_tracker.lock().await;
            spans.get_span(parent_span_id).cloned().map(|(id, _)| id)
        } else {
            None
        };

        let trace_id = self.ensure_trace_id().await;

        // Create the span observation
        let event = observation_create_from(&trace_id, &observation_id, &mut span_data,
            parent_id);
        self.batch_manager.add_event(event).await;
    }

    /// Closes the observation tracked for `span_id` by sending an update
    /// carrying an end time. No-op when the span was never tracked.
    pub async fn handle_span_close(&self, span_id: u64) {
        let Some((observation_id, langfuse_type)) =
            self.span_tracker.lock().await.remove_span(span_id)
        else {
            return;
        };

        let trace_id = self.ensure_trace_id().await;
        let event = IngestionEvent::new_observation_update(ObservationBody {
            id: Some(Some(observation_id.clone())),
            r#type: langfuse_type,
            trace_id: Some(Some(trace_id.clone())),
            end_time: Some(Some(Utc::now().to_rfc3339())),
            ..Default::default()
        });
        self.batch_manager.add_event(event).await;
    }

    /// Returns the current trace id, creating one (and emitting the
    /// corresponding trace-create event) when none exists yet.
    pub async fn ensure_trace_id(&self) -> String {
        let mut spans = self.span_tracker.lock().await;
        if let Some(id) = spans.current_trace_id.clone() {
            return id;
        }

        let trace_id = Uuid::new_v4().to_string();
        spans.current_trace_id = Some(trace_id.clone());

        let event = IngestionEvent::new_trace_create(TraceBody {
            id: Some(Some(trace_id.clone())),
            // The trace is named after its creation time (unix seconds).
            name: Some(Some(Utc::now().timestamp().to_string())),
            timestamp: Some(Some(Utc::now().to_rfc3339())),
            public: Some(Some(false)),
            ..Default::default()
        });
        self.batch_manager.add_event(event).await;

        trace_id
    }

    /// Applies fields recorded after span creation (via `on_record` or
    /// `on_event`) to the matching observation through an update event.
    pub async fn handle_record(&self, span_id: u64, metadata: serde_json::Map<String, Value>) {
        let Some((observation_id, langfuse_type)) =
            self.span_tracker.lock().await.get_span(span_id).cloned()
        else {
            return;
        };

        let trace_id = self.ensure_trace_id().await;
        // Reuse SpanData's extraction helpers on the recorded fields.
        let metadata = SpanData::from(metadata);
        let remaining = metadata.remaining_metadata().map(Into::into);
        let swiftide_usage =
            metadata.get::<swiftide_core::chat_completion::Usage>("langfuse.usage");

        let event = IngestionEvent::new_observation_update(ObservationBody {
            id: Some(Some(observation_id.clone())),
            trace_id: Some(Some(trace_id.clone())),
            r#type: langfuse_type,
            metadata: Some(remaining),
            input: Some(metadata.get("langfuse.input")),
            output: Some(metadata.get("langfuse.output")),
            model: Some(metadata.get("langfuse.model")),
            model_parameters: Some(metadata.get("langfuse.model_parameters")),
            version:
                Some(metadata.get("langfuse.version")),
            usage: swiftide_usage.map(|u| Box::new(u.into())),
            status_message: Some(metadata.get("langfuse.status_message")),
            environment: Some(metadata.get("langfuse.environment")),
            ..Default::default()
        });
        self.batch_manager.add_event(event).await;
    }
}

impl<S> Layer<S> for LangfuseLayer
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    fn enabled(&self, _metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool {
        // Enable this layer for all spans and events
        true
    }

    fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
        let span_id = id.into_u64();
        // span_scope yields the span itself first; nth(1) is its direct parent.
        let parent_span_id = ctx
            .span_scope(id)
            .and_then(|mut scope| scope.nth(1))
            .map(|parent| parent.id().into_u64());

        let mut visitor = JsonVisitor::new();
        attrs.record(&mut visitor);

        let span_data = SpanData {
            observation_id: Uuid::new_v4().to_string(),
            name: attrs.metadata().name().to_string(),
            start_time: Utc::now().to_rfc3339(),
            level: map_level(attrs.metadata().level()),
            metadata: visitor.recorded_fields,
            parent_span_id,
        };

        // Forwarding is async; hop onto the tokio runtime so the subscriber
        // callback itself stays non-blocking.
        let layer = self.clone();
        tokio::spawn(async move { layer.handle_span(span_id, span_data).await });
    }

    fn on_close(&self, id: Id, _ctx: Context<'_, S>) {
        let span_id = id.into_u64();
        let layer = self.clone();
        tokio::spawn(async move { layer.handle_span_close(span_id).await });
    }

    fn on_record(&self, span: &Id, values: &span::Record<'_>, _ctx: Context<'_, S>) {
        let span_id = span.into_u64();
        let mut visitor = JsonVisitor::new();
        values.record(&mut visitor);
        let metadata = visitor.recorded_fields;

        if !metadata.is_empty() {
            let layer = self.clone();
            tokio::spawn(async move { layer.handle_record(span_id, metadata).await });
        }
    }

    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
        let mut visitor = JsonVisitor::new();
        event.record(&mut visitor);
        let metadata = visitor.recorded_fields;

        // Events are attached to the currently entered span, if any;
        // events outside any span are dropped.
        if let Some(span_id) = ctx.lookup_current().map(|span| span.id().into_u64()) {
            let layer = self.clone();
            tokio::spawn(async move {
                layer.handle_record(span_id, metadata).await
            });
        }
    }
}

/// Collects the fields recorded on a span or event into a JSON map.
#[derive(Debug)]
struct JsonVisitor {
    recorded_fields: serde_json::Map<String, Value>,
}

impl JsonVisitor {
    fn new() -> Self {
        Self {
            recorded_fields: serde_json::Map::new(),
        }
    }

    fn insert_value(&mut self, field: &Field, value: Value) {
        self.recorded_fields.insert(field.name().to_string(), value);
    }
}

// Generates a `Visit` method that stores the value directly as JSON.
macro_rules! record_field {
    ($fn_name:ident, $type:ty) => {
        fn $fn_name(&mut self, field: &Field, value: $type) {
            self.insert_value(field, Value::from(value));
        }
    };
}

impl Visit for JsonVisitor {
    record_field!(record_i64, i64);
    record_field!(record_u64, u64);
    record_field!(record_bool, bool);

    fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
        self.insert_value(field, Value::String(format!("{value:?}")));
    }

    fn record_str(&mut self, field: &Field, value: &str) {
        // Strings containing valid JSON are stored parsed, so structured
        // payloads (e.g. serialized usage) survive the round trip.
        let value = Value::from_str(value).unwrap_or_else(|_| Value::String(value.to_string()));
        self.insert_value(field, value);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::Mutex;
    use tracing::{Level, subscriber::set_global_default};
    use tracing_subscriber::prelude::*;

    /// Batch manager stub that records events in memory instead of sending
    /// them over the network.
    #[derive(Clone)]
    struct InMemoryBatchManager {
        pub events: Arc<Mutex<Vec<crate::models::ingestion_event::IngestionEvent>>>,
    }

    #[async_trait::async_trait]
    impl crate::langfuse_batch_manager::BatchManagerTrait for InMemoryBatchManager {
        async fn add_event(&self, event: crate::models::ingestion_event::IngestionEvent) {
            self.events.lock().await.push(event);
        }

        async fn flush(&self) -> anyhow::Result<()> {
            Ok(())
        }

        fn boxed(&self) -> Box<dyn crate::langfuse_batch_manager::BatchManagerTrait + Send + Sync> {
            Box::new(Self {
                events: Arc::clone(&self.events),
            })
        }
    }

    #[test_log::test(tokio::test)]
    async fn test_generation_span_fields_are_correct_and_single_observation_created() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let batch_mgr = InMemoryBatchManager {
            events: Arc::clone(&events),
        };
        let langfuse_layer = LangfuseLayer {
            batch_manager: batch_mgr.boxed(),
            span_tracker:
                Arc::new(Mutex::new(SpanTracker::new())),
        };

        // Route fmt-layer output to a sink so the test stays quiet.
        let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::sink());
        let subscriber = tracing_subscriber::Registry::default()
            .with(langfuse_layer)
            .with(
                tracing_subscriber::fmt::layer()
                    .with_writer(non_blocking)
                    .with_test_writer(),
            );
        set_global_default(subscriber).unwrap();

        let usage = swiftide_core::chat_completion::Usage {
            prompt_tokens: 5,
            completion_tokens: 9,
            total_tokens: 14,
            details: None,
        };

        // Start a GENERATION span, record fields, and drop/end.
        {
            let span = tracing::span!(
                Level::INFO,
                "prompt",
                langfuse.type = "GENERATION",
                langfuse.input = "sample-in",
                langfuse.output = "sample-out",
                langfuse.usage = serde_json::to_string(&usage).unwrap()
            );
            let _enter = span.enter();
            // Span ends here (dropped)
        }

        // Allow async processing to complete
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        let events = events.lock().await;

        // There should be one observation create (and likely one trace, but we check for GENERATION
        // only)
        let generation_events: Vec<_> = events
            .iter()
            .filter(|e| {
                matches!(
                    e,
                    crate::models::ingestion_event::IngestionEvent::ObservationCreate(_)
                )
            })
            .collect();
        assert_eq!(generation_events.len(), 1);

        if let crate::models::ingestion_event::IngestionEvent::ObservationCreate(obs) =
            &generation_events[0]
        {
            let body = &obs.body;
            assert_eq!(body.r#type, crate::models::ObservationType::Generation);
            assert_eq!(body.input, Some(Some("sample-in".into())));
            assert_eq!(body.output, Some(Some("sample-out".into())));
            assert_eq!(
                body.usage
                    .as_ref()
                    .map(|b| serde_json::to_value(&**b).unwrap()),
                Some(serde_json::json!({"input": 5, "output": 9, "total": 14, "unit": "TOKENS"}))
            );
        } else {
            panic!("Did not capture a GENERATION observation as expected");
        }
    }
}


================================================
FILE: swiftide-langfuse/tests/full_flow.rs
================================================
use std::sync::{Arc, Mutex};

use reqwest::Client;
use
swiftide_langfuse::{Configuration, LangfuseBatchManager, LangfuseLayer};
use tokio::task::yield_now;
use tracing::{Level, info, span};
use tracing_subscriber::{Registry, layer::SubscriberExt};
use wiremock::{
    Mock, MockServer, ResponseTemplate,
    matchers::{method, path},
};

// End-to-end check: tracing spans/events flow through the layer and batch
// manager and arrive at the (mocked) Langfuse ingestion endpoint.
#[test_log::test(tokio::test)]
async fn integration_tracing_layer_sends_to_langfuse() {
    // Start Wiremock server
    let mock_server = MockServer::start().await;

    // Mock a successful ingestion response
    let response = ResponseTemplate::new(200).set_body_raw(
        r#"{"successes":[{"id":"abc","status":200}],"errors":[]}"#,
        "application/json",
    );

    // Capture the raw request body for snapshot assertion below.
    let body = Arc::new(Mutex::new(None));
    let body_clone = body.clone();

    Mock::given(method("POST"))
        .and(path("/api/public/ingestion"))
        .respond_with(move |req: &wiremock::Request| {
            let body_clone = body_clone.clone();
            let body_str = String::from_utf8_lossy(&req.body).to_string();
            let mut lock = body_clone.lock().unwrap();
            *lock = Some(body_str);
            response.clone()
        })
        .expect(1)
        .mount(&mock_server)
        .await;

    // Prepare Langfuse config to point to the mock server
    let config = Configuration {
        base_path: mock_server.uri(),
        user_agent: Some("integration-test".into()),
        client: Client::new(),
        basic_auth: Some(("PUBLIC".into(), Some("SECRET".into()))),
        ..Default::default()
    };

    // Set up tracing layer
    let batch_manager = LangfuseBatchManager::new(config);
    let layer = LangfuseLayer::from_batch_manager(&batch_manager);
    batch_manager.clone().spawn();

    // Install subscriber and layer
    let subscriber = Registry::default().with(layer);
    tracing::subscriber::with_default(subscriber, || {
        let span = span!(
            Level::INFO,
            "test_span",
            "langfuse.input" = "LANGFUSE INPUT",
            "langfuse.output" = "LANGFUSE OUTPUT",
            "langfuse.model" = "LANGFUSE MODEL",
            "otel.name" = "OTEL.OVERWRITE",
            foo = 42
        );
        let _enter = span.enter();

        info!(bar = "baz", "Hello from integration test");
    });

    // Give some time for the async tasks to run
    yield_now().await;

    // Force the flush as the batch manager is not dropped yet
    batch_manager.flush().await.unwrap();

    // Assert request received
    mock_server.verify().await;

    // Normalize non-deterministic values before snapshotting the payload.
    insta::with_settings!({
        filters => vec![
            // UUID v4/v5 pattern
            (r#""[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}""#, r#""<UUID>""#),
            // Improved ISO8601 datetime filter, matching both Z and offsets
            (r#""\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})""#, r#""<TIMESTAMP>""#),
            // Unix timestamp (with optional ms)
            (r#""\d{10,13}""#, r#""<UNIX_TIMESTAMP>""#),
        ]
    }, {
        insta::assert_snapshot!(body.lock().unwrap().as_ref().unwrap())
    });
}


================================================
FILE: swiftide-langfuse/tests/snapshots/full_flow__integration_tracing_layer_sends_to_langfuse.snap
================================================
---
source: swiftide-langfuse/tests/full_flow.rs
expression: body.lock().unwrap().as_ref().unwrap()
---
{"batch":[{"body":{"id":"<UUID>","timestamp":"<TIMESTAMP>","name":"<UNIX_TIMESTAMP>","public":false},"id":"<UUID>","timestamp":"<TIMESTAMP>","type":"trace-create"},{"body":{"id":"<UUID>","traceId":"<UUID>","type":"SPAN","name":"OTEL.OVERWRITE","startTime":"<TIMESTAMP>","model":"LANGFUSE MODEL","modelParameters":null,"input":"LANGFUSE INPUT","version":null,"metadata":{"foo":42,"otel.name":"OTEL.OVERWRITE"},"output":"LANGFUSE OUTPUT","level":"DEFAULT","statusMessage":null,"parentObservationId":null,"environment":null},"id":"<UUID>","timestamp":"<TIMESTAMP>","type":"observation-create"},{"body":{"id":"<UUID>","traceId":"<UUID>","type":"SPAN","model":null,"modelParameters":null,"input":null,"version":null,"metadata":{"bar":"baz","message":"Hello from integration test"},"output":null,"statusMessage":null,"environment":null},"id":"<UUID>","timestamp":"<TIMESTAMP>","type":"observation-update"},{"body":{"id":"<UUID>","traceId":"<UUID>","type":"SPAN","endTime":"<TIMESTAMP>"},"id":"<UUID>","timestamp":"<TIMESTAMP>","type":"observation-update"}]}


================================================
FILE: swiftide-macros/Cargo.toml
================================================
cargo-features = ["edition2024"]

[package]
name = "swiftide-macros"
version.workspace = true
edition.workspace = true
license.workspace = true
readme.workspace = true
keywords.workspace = true
description.workspace = true
categories.workspace = true
repository.workspace = true
homepage.workspace = true

[lib]
proc-macro = true

[dependencies]
quote = { workspace = true }
syn = { workspace = true }
darling = { workspace = true }
proc-macro2 = { workspace = true }
convert_case = { workspace = true }

# Macro dependencies
anyhow.workspace = true
async-trait.workspace = true
serde = { workspace = true, optional = true }
serde_json = { workspace = true, optional = true }
schemars = { workspace = true, features = ["derive"] }

[dev-dependencies]
pretty_assertions.workspace = true
rustversion = "1.0.18"
trybuild = "1.0"
prettyplease = "0.2.25"
insta.workspace = true
swiftide = { path = "../swiftide/" }
swiftide-core = { path = "../swiftide-core/" }
tokio = { workspace = true, features = ["full"] }

[lints]
workspace = true

[features]
# TODO: Clean up feature flag
default = ["swiftide-agents"]
swiftide-agents = ["dep:serde", "dep:serde_json"]

[package.metadata.docs.rs]
all-features = true
cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"]
rustdoc-args = ["--cfg", "docsrs"]


================================================
FILE: swiftide-macros/src/indexing_transformer.rs
================================================
use darling::{Error, FromMeta, ast::NestedMeta};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Fields, Ident, ItemStruct};

// Options accepted by the `#[indexing_transformer(...)]` attribute.
#[derive(FromMeta, Default)]
#[darling(default)]
struct TransformerArgs {
    // When set, a `pub const NAME` with this value is generated on the struct.
    metadata_field_name: Option<String>,
    // Path to a prompt file included via `include_str!` as the default prompt.
    default_prompt_file: Option<String>,
    derive: DeriveOptions,
}

// Opt-outs for the derives added to the generated struct.
#[derive(FromMeta, Debug, Default)]
#[darling(default)]
struct DeriveOptions {
    skip_debug: bool,
    skip_clone: bool,
    skip_default: bool,
}

#[allow(clippy::too_many_lines)]
pub(crate) fn
indexing_transformer_impl(args: TokenStream, input: ItemStruct) -> TokenStream {
    let args = match parse_args(args) {
        Ok(args) => args,
        Err(e) => return e.write_errors(),
    };

    let struct_name = &input.ident;
    // The generated code relies on derive_builder's `<Name>Builder` naming.
    let builder_name = Ident::new(
        &format!("{struct_name}Builder"),
        proc_macro2::Span::call_site(),
    );
    let vis = &input.vis;
    let attrs = &input.attrs;
    // User-declared fields are re-emitted ahead of the injected ones.
    let existing_fields =
        extract_existing_fields(input.fields).collect::<Vec<proc_macro2::TokenStream>>();

    let metadata_field_name = match args.metadata_field_name {
        Some(name) => quote! { pub const NAME: &str = #name; },
        None => quote! {},
    };

    let prompt_template_struct_attr = match &args.default_prompt_file {
        Some(_file) => quote! {
            #[builder(default = "default_prompt()")]
            prompt_template: hidden::Prompt,
        },
        None => quote! {},
    };

    let default_prompt_fn = match &args.default_prompt_file {
        Some(file) => quote! {
            fn default_prompt() -> hidden::Prompt {
                include_str!(#file).into()
            }
        },
        None => quote! {},
    };

    // Always derive the builder; Debug/Clone are opt-out via the attribute.
    let derive = {
        let mut tokens = vec![quote! { hidden::Builder}];
        if !args.derive.skip_debug {
            tokens.push(quote! { Debug });
        }
        if !args.derive.skip_clone {
            tokens.push(quote! { Clone });
        }
        quote! { #[derive(#(#tokens),*)] }
    };

    let default_impl = if args.derive.skip_default {
        quote! {}
    } else {
        quote! {
            impl Default for #struct_name {
                fn default() -> Self {
                    #builder_name::default().build().unwrap()
                }
            }
        }
    };

    quote!
    {
        mod hidden {
            pub use std::sync::Arc;
            pub use anyhow::Result;
            pub use derive_builder::Builder;
            pub use swiftide_core::{
                indexing::{IndexingDefaults},
                prompt::Prompt,
                chat_completion::errors::LanguageModelError,
                SimplePrompt, Transformer, WithIndexingDefaults
            };
        }

        #metadata_field_name

        #derive
        #[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))]
        #(#attrs)*
        #vis struct #struct_name {
            #(#existing_fields)*
            #[builder(setter(custom), default)]
            client: Option<hidden::Arc<dyn hidden::SimplePrompt>>,
            #prompt_template_struct_attr
            #[builder(default)]
            concurrency: Option<usize>,
            #[builder(private, default)]
            indexing_defaults: Option<hidden::IndexingDefaults>,
        }

        #default_impl

        impl #struct_name {
            /// Creates a new builder for the transformer
            pub fn builder() -> #builder_name {
                #builder_name::default()
            }

            /// Build a new transformer from a client
            pub fn from_client(client: impl hidden::SimplePrompt + 'static) -> #builder_name {
                #builder_name::default().client(client).to_owned()
            }

            /// Create a new transformer from a client
            pub fn new(client: impl hidden::SimplePrompt + 'static) -> Self {
                #builder_name::default().client(client).build().unwrap()
            }

            /// Set the concurrency level for the transformer
            #[must_use]
            pub fn with_concurrency(mut self, concurrency: usize) -> Self {
                self.concurrency = Some(concurrency);
                self
            }

            /// Prompts either the client provided to the transformer or a default client
            /// provided on the indexing pipeline
            ///
            /// # Errors
            ///
            /// Gives an error if no (default) client is provided
            async fn prompt(&self, prompt: hidden::Prompt) -> hidden::Result<String, hidden::LanguageModelError> {
                if let Some(client) = &self.client {
                    return client.prompt(prompt).await
                };

                let Some(defaults) = &self.indexing_defaults.as_ref() else {
                    return Err(hidden::LanguageModelError::PermanentError("No client provided".into()))
                };

                let Some(client) = defaults.simple_prompt() else {
                    return Err(hidden::LanguageModelError::PermanentError("No client provided".into()))
                };
                client.prompt(prompt).await
            }
        }

        impl #builder_name {
            pub fn client(&mut self, client: impl hidden::SimplePrompt + 'static) -> &mut Self {
                self.client = Some(Some(hidden::Arc::new(client) as hidden::Arc<dyn hidden::SimplePrompt>));
                self
            }
        }

        impl hidden::WithIndexingDefaults for #struct_name {
            fn with_indexing_defaults(&mut self, defaults: hidden::IndexingDefaults) {
                self.indexing_defaults = Some(defaults);
            }
        }

        #default_prompt_fn
    }
}

// Parses the attribute's token stream into `TransformerArgs` via darling.
fn parse_args(args: TokenStream) -> Result<TransformerArgs, Error> {
    let attr_args = NestedMeta::parse_meta_list(args)?;
    TransformerArgs::from_list(&attr_args)
}

// Re-emits each user-declared field (attributes, visibility, name, type) so
// it can be spliced back into the generated struct.
fn extract_existing_fields(fields: Fields) -> impl Iterator<Item = proc_macro2::TokenStream> {
    fields.into_iter().map(|field| {
        let field_name = &field.ident;
        let field_type = &field.ty;
        let field_vis = &field.vis;
        let field_attrs = &field.attrs;

        quote! {
            #(#field_attrs)*
            #field_vis #field_name: #field_type,
        }
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;
    use syn::{ItemStruct, parse_quote};

    #[test]
    fn test_includes_doc_comments() {
        let input: ItemStruct = parse_quote! {
            /// This is a test struct
            pub struct TestStruct {
                /// This is a test field
                pub test_field: String,
            }
        };
        let args: TokenStream = quote!();
        let output = indexing_transformer_impl(args, input);

        let expected_output = quote!
        {
            mod hidden {
                pub use std::sync::Arc;
                pub use anyhow::Result;
                pub use derive_builder::Builder;
                pub use swiftide_core::{
                    indexing::{IndexingDefaults},
                    prompt::Prompt,
                    chat_completion::errors::LanguageModelError,
                    SimplePrompt, Transformer, WithIndexingDefaults
                };
            }

            #[derive(hidden::Builder, Debug, Clone)]
            #[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))]
            /// This is a test struct
            pub struct TestStruct {
                /// This is a test field
                pub test_field: String,
                #[builder(setter(custom), default)]
                client: Option<hidden::Arc<dyn hidden::SimplePrompt>>,
                #[builder(default)]
                concurrency: Option<usize>,
                #[builder(private, default)]
                indexing_defaults: Option<hidden::IndexingDefaults>,
            }

            impl Default for TestStruct {
                fn default() -> Self {
                    TestStructBuilder::default().build().unwrap()
                }
            }

            impl TestStruct {
                /// Creates a new builder for the transformer
                pub fn builder() -> TestStructBuilder {
                    TestStructBuilder::default()
                }

                /// Build a new transformer from a client
                pub fn from_client(client: impl hidden::SimplePrompt + 'static) -> TestStructBuilder {
                    TestStructBuilder::default().client(client).to_owned()
                }

                /// Create a new transformer from a client
                pub fn new(client: impl hidden::SimplePrompt + 'static) -> Self {
                    TestStructBuilder::default().client(client).build().unwrap()
                }

                /// Set the concurrency level for the transformer
                #[must_use]
                pub fn with_concurrency(mut self, concurrency: usize) -> Self {
                    self.concurrency = Some(concurrency);
                    self
                }

                /// Prompts either the client provided to the transformer or a default client
                /// provided on the indexing pipeline
                ///
                /// # Errors
                ///
                /// Gives an error if no (default) client is provided
                async fn prompt(&self, prompt: hidden::Prompt) -> hidden::Result<String, hidden::LanguageModelError> {
                    if let Some(client) = &self.client {
                        return client.prompt(prompt).await
                    };

                    let Some(defaults) = &self.indexing_defaults.as_ref() else {
                        return Err(hidden::LanguageModelError::PermanentError("No client provided".into()))
                    };

                    let Some(client) = defaults.simple_prompt() else {
                        return Err(hidden::LanguageModelError::PermanentError("No client provided".into()))
                    };

                    client.prompt(prompt).await
                }
            }

            impl TestStructBuilder {
                pub fn client(&mut self, client: impl hidden::SimplePrompt + 'static) -> &mut Self {
                    self.client = Some(Some(hidden::Arc::new(client) as hidden::Arc<dyn hidden::SimplePrompt>));
                    self
                }
            }

            impl hidden::WithIndexingDefaults for TestStruct {
                fn with_indexing_defaults(&mut self, defaults: hidden::IndexingDefaults) {
                    self.indexing_defaults = Some(defaults);
                }
            }
        };

        assert_eq!(output.to_string(), expected_output.to_string());
    }
}


================================================
FILE: swiftide-macros/src/lib.rs
================================================
// show feature flags in the generated documentation
// https://doc.rust-lang.org/rustdoc/unstable-features.html#extensions-to-the-doc-attribute
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(docsrs, doc(auto_cfg))]
#![doc(html_logo_url = "https://github.com/bosun-ai/swiftide/raw/master/images/logo.png")]
//! This crate provides macros for generating boilerplate code
//! for indexing transformers
use proc_macro::TokenStream;

mod indexing_transformer;
#[cfg(test)]
mod test_utils;
mod tool;

use indexing_transformer::indexing_transformer_impl;
use syn::{DeriveInput, ItemFn, ItemStruct, parse_macro_input};
use tool::{tool_attribute_impl, tool_derive_impl};

/// Generates boilerplate for an indexing transformer.
#[proc_macro_attribute]
pub fn indexing_transformer(args: TokenStream, input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as ItemStruct);

    indexing_transformer_impl(args.into(), input).into()
}

#[proc_macro_attribute]
/// Creates a `Tool` from an async function.
///
/// # Example
/// ```ignore
/// #[tool(description = "Searches code", param(name = "code_query", description = "The code query"))]
/// pub async fn search_code(context: &dyn AgentContext, code_query: &str) -> Result<ToolOutput,
/// ToolError> {
///     Ok("hello".into())
/// }
///
/// // The tool can then be used with agents:
/// Agent::builder().tools([search_code()])
///
/// // Or
///
/// Agent::builder().tools([SearchCode::default()])
/// ```
pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as ItemFn);

    tool_attribute_impl(&args.into(), &input).into()
}

/// Derive `Tool` on a struct.
///
/// Useful if your structs have internal state and you want to use it in your tool.
///
/// # Example
/// ```ignore
/// #[derive(Clone, Tool)]
/// #[tool(description = "Searches code", param(name = "code_query", description = "The code query"))]
/// pub struct SearchCode {
///     search_command: String
/// }
///
/// impl SearchCode {
///     pub async fn search_code(&self, context: &dyn AgentContext, code_query: &str) -> Result<ToolOutput, ToolError> {
///         context.exec_cmd(&self.search_command.into()).await.map(Into::into)
///     }
/// }
/// ```
#[proc_macro_derive(Tool, attributes(tool))]
pub fn derive_tool(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    // Darling/syn errors are surfaced as compile errors at the call site.
    match tool_derive_impl(&input) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}


================================================
FILE: swiftide-macros/src/test_utils.rs
================================================
// Pretty-prints a token stream so test diffs are readable.
pub fn pretty_macro_output(item: &proc_macro2::TokenStream) -> String {
    let file = syn::parse_file(&item.to_string())
        .unwrap_or_else(|_| panic!("Failed to parse token stream: {}", &item.to_string()));

    prettyplease::unparse(&file)
}

// Add a macro that pretty compares two token streams using the above called `assert_ts_eq!`
#[macro_export]
macro_rules!
assert_ts_eq {
    ($left:expr, $right:expr) => {{
        let left_pretty = $crate::test_utils::pretty_macro_output(&$left);
        let right_pretty = $crate::test_utils::pretty_macro_output(&$right);

        pretty_assertions::assert_eq!(left_pretty, right_pretty);
    }};
}


================================================
FILE: swiftide-macros/src/tool/args.rs
================================================
use convert_case::{Case, Casing as _};
use darling::{Error, FromMeta, ast::NestedMeta};
use proc_macro2::TokenStream;
use quote::{ToTokens as _, quote};
use syn::{FnArg, Ident, ItemFn, Pat, PatType, parse_quote};

// Parsed `#[tool(...)]` attribute arguments.
#[derive(FromMeta, Default, Debug)]
pub struct ToolArgs {
    #[darling(default)]
    /// Name of the tool
    /// Defaults to the underscored version of the function name or struct
    name: String,

    /// Name of the function to call
    /// Defaults to the underscored version of the function name or struct
    #[darling(default)]
    fn_name: String,

    /// Description of the tool
    description: Description,

    /// Parameters the tool can take
    #[darling(multiple, rename = "param")]
    params: Vec<ParamOptions>,
}

// A single `param(...)` entry on the attribute.
#[derive(FromMeta, Debug, Default)]
#[darling(default)]
pub struct ParamOptions {
    pub name: String,
    pub description: String,
    /// Backwards compatibility: optional JSON type hint (string based)
    pub json_type: Option<String>,
    /// Explicit rust type override parsed from the attribute
    pub rust_type: Option<syn::Type>,
    pub required: Option<bool>,
    // Filled in by `infer_param_types`; never parsed from the attribute.
    #[darling(skip)]
    pub resolved_type: Option<syn::Type>,
}

// A tool description: either an inline string literal or a path to a const.
#[derive(Debug)]
pub enum Description {
    Literal(String),
    Path(syn::Path),
}

impl Default for Description {
    fn default() -> Self {
        Description::Literal(String::new())
    }
}

impl FromMeta for Description {
    fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
        match expr {
            syn::Expr::Lit(lit) => {
                if let syn::Lit::Str(s) = &lit.lit {
                    Ok(Description::Literal(s.value()))
                } else {
                    Err(Error::unsupported_format(
                        "expected a string literal or a const",
                    ))
                }
            }
            syn::Expr::Path(path) =>
                Ok(Description::Path(path.path.clone())),
            _ => Err(Error::unsupported_format(
                "expected a string literal or a const",
            )),
        }
    }
}

impl ToolArgs {
    /// Builds `ToolArgs` for the `#[tool]` attribute macro: parses the
    /// attribute, pairs each declared param with the matching function
    /// argument's (owned) type, infers/validates types, and derives the tool
    /// name from the function identifier when not set explicitly.
    pub fn try_from_attribute_input(input: &ItemFn, args: TokenStream) -> Result<Self, Error> {
        validate_first_argument_is_agent_context(input)?;

        let attr_args = NestedMeta::parse_meta_list(args)?;
        let mut args = ToolArgs::from_list(&attr_args)?;

        // skip(1): the first argument is the agent context, not a tool param.
        for arg in input.sig.inputs.iter().skip(1) {
            if let FnArg::Typed(PatType { pat, ty, .. }) = arg
                && let Pat::Ident(ident) = &**pat
            {
                let ty = as_owned_ty(ty);
                if let Some(param) = args.params.iter_mut().find(|p| ident.ident == p.name) {
                    param.rust_type = Some(ty);
                }
            }
        }

        args.infer_param_types()?;
        validate_spec_and_fn_args_match(&args, input)?;
        args.with_name_from_ident(&input.sig.ident);

        Ok(args)
    }

    /// Resolves each parameter's final Rust type and its `required` flag.
    ///
    /// Precedence: explicit `rust_type`, then the legacy `json_type` hint,
    /// then `String`. `required = true` on an `Option<_>` type is rejected;
    /// `required = false` on a non-optional type wraps it in `Option`.
    pub fn infer_param_types(&mut self) -> Result<(), Error> {
        for param in &mut self.params {
            let mut ty = if let Some(ty) = param.rust_type.clone() {
                ty
            } else if let Some(json_type) = &param.json_type {
                json_type_to_rust_type(json_type)
            } else {
                syn::parse_quote!
                { String }
            };

            let is_option = is_option_type(&ty);

            match param.required {
                Some(true) if is_option => {
                    return Err(Error::custom(format!(
                        "The parameter {} is marked as required but has an optional type",
                        param.name
                    )));
                }
                Some(false) if !is_option => {
                    ty = wrap_type_in_option(ty);
                }
                None if is_option => {
                    param.required = Some(false);
                }
                None => {
                    param.required = Some(true);
                }
                _ => {}
            }

            param.resolved_type = Some(ty);
        }

        Ok(())
    }

    // Defaults the tool and function names to the snake_case identifier.
    pub fn with_name_from_ident(&mut self, ident: &syn::Ident) {
        if self.name.is_empty() {
            self.name = ident.to_string().to_case(Case::Snake);
        }

        if self.fn_name.is_empty() {
            self.fn_name = ident.to_string().to_case(Case::Snake);
        }
    }

    pub fn tool_name(&self) -> &str {
        &self.name
    }

    pub fn fn_name(&self) -> &str {
        &self.fn_name
    }

    pub fn tool_description(&self) -> &Description {
        &self.description
    }

    pub fn tool_params(&self) -> &[ParamOptions] {
        &self.params
    }

    /// Emits the argument expressions used when invoking the tool function:
    /// `Vec` params are moved out of the args struct, others are borrowed.
    pub fn derive_invoke_args(&self) -> Vec<TokenStream> {
        self.params
            .iter()
            .map(|param| {
                let ident = syn::Ident::new(&param.name, proc_macro2::Span::call_site());
                if param.should_pass_owned() {
                    quote! { args.#ident }
                } else {
                    quote! { &args.#ident }
                }
            })
            .collect()
    }

    /// Generates the serde/schemars args struct for the tool's parameters;
    /// empty token stream when the tool takes no parameters.
    pub fn args_struct(&self) -> TokenStream {
        if self.params.is_empty() {
            return quote! {};
        }

        let mut fields = Vec::new();

        for param in &self.params {
            let ty = param
                .resolved_type
                .as_ref()
                .expect("parameter types should be resolved");
            let ident = syn::Ident::new(&param.name, proc_macro2::Span::call_site());
            fields.push(quote! {
                pub #ident: #ty
            });
        }

        let args_struct_ident = self.args_struct_ident();

        quote!
{ #[derive( ::swiftide::reexports::serde::Serialize, ::swiftide::reexports::serde::Deserialize, ::swiftide::reexports::schemars::JsonSchema, Debug )] #[schemars(crate = "::swiftide::reexports::schemars", deny_unknown_fields)] pub struct #args_struct_ident { #(#fields),* } } } pub fn args_struct_ident(&self) -> Ident { syn::Ident::new( &format!("{}Args", self.name.to_case(Case::Pascal)), proc_macro2::Span::call_site(), ) } } fn validate_spec_and_fn_args_match(tool_args: &ToolArgs, item_fn: &ItemFn) -> Result<(), Error> { let mut found_spec_arg_names = tool_args .params .iter() .map(|param| param.name.clone()) .collect::<Vec<_>>(); found_spec_arg_names.sort(); let mut seen_arg_names = vec![]; item_fn.sig.inputs.iter().skip(1).for_each(|arg| { if let FnArg::Typed(PatType { pat, .. }) = arg && let Pat::Ident(ident) = &**pat { seen_arg_names.push(ident.ident.to_string()); } }); seen_arg_names.sort(); let mut errors = Error::accumulator(); if found_spec_arg_names != seen_arg_names { let missing_args = found_spec_arg_names .iter() .filter(|name| !seen_arg_names.contains(name)) .collect::<Vec<_>>(); let missing_params = seen_arg_names .iter() .filter(|name| !found_spec_arg_names.contains(name)) .collect::<Vec<_>>(); if !missing_args.is_empty() { errors.push(Error::custom(format!( "The following parameters are missing from the function signature: {missing_args:?}" ))); } if !missing_params.is_empty() { errors.push(Error::custom(format!( "The following parameters are missing from the spec: {missing_params:?}" ))); } } errors.finish()?; Ok(()) } fn json_type_to_rust_type(json_type: &str) -> syn::Type { match json_type.to_ascii_lowercase().as_str() { "number" => syn::parse_quote! { usize }, "boolean" => syn::parse_quote! { bool }, "array" => syn::parse_quote! { Vec<String> }, "object" => syn::parse_quote! { ::serde_json::Value }, // default to string if nothing is specified _ => syn::parse_quote! 
{ String },
    }
}

/// Returns true when `ty` is syntactically an `Option<...>` path (checked on
/// the last path segment, so fully qualified spellings also match).
fn is_option_type(ty: &syn::Type) -> bool {
    if let syn::Type::Path(type_path) = ty {
        if type_path.qself.is_some() {
            return false;
        }
        return type_path
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "Option");
    }
    false
}

/// Wraps `ty` in `Option<...>` unless it already is an `Option`.
fn wrap_type_in_option(ty: syn::Type) -> syn::Type {
    if is_option_type(&ty) {
        ty
    } else {
        syn::parse_quote! { Option<#ty> }
    }
}

/// Maps a (possibly referenced) parameter type to the owned type stored in
/// the generated args struct: `&str` -> `String`, `&[T]` -> `Vec<T>`,
/// `&Vec<T>` -> `Vec<T>`, `&Option<T>` -> `Option<owned T>`; any other
/// referenced path falls back to `String`. Non-reference types pass through
/// unchanged.
fn as_owned_ty(ty: &syn::Type) -> syn::Type {
    if let syn::Type::Reference(r) = ty {
        if let syn::Type::Path(p) = &*r.elem {
            if p.path.is_ident("str") {
                return parse_quote!(String);
            }
            // FIX: `Path::is_ident` is documented to return false for any
            // path whose segment carries generic arguments, so the previous
            // `p.path.is_ident("Vec")` check could never match `&Vec<T>` and
            // such params fell through to the `String` fallback. Match on the
            // last segment's ident instead, mirroring the Option branch.
            if let Some(last_segment) = p.path.segments.last()
                && last_segment.ident == "Vec"
                && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
                && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
            {
                let inner = as_owned_ty(inner_ty);
                return parse_quote!(Vec<#inner>);
            }
            if let Some(last_segment) = p.path.segments.last()
                && last_segment.ident.to_string().as_str() == "Option"
                && let syn::PathArguments::AngleBracketed(generics) = &last_segment.arguments
                && let Some(syn::GenericArgument::Type(inner_ty)) = generics.args.first()
            {
                let inner_ty = as_owned_ty(inner_ty);
                return parse_quote!(Option<#inner_ty>);
            }
            // Unknown referenced path type: default to an owned String.
            return parse_quote!(String);
        }
        if let syn::Type::Slice(slice_type) = &*r.elem {
            // slice_type.elem is T. We'll replace with Vec<T>.
            let elem = &slice_type.elem;
            return parse_quote!(Vec<#elem>);
        }
        // Proc-macro context: a panic here aborts expansion and surfaces as a
        // compile error on the annotated item.
        panic!("Unsupported reference type");
    } else {
        ty.to_owned()
    }
}

/// Returns true when `ty` is syntactically a `Vec<...>` path.
fn is_vec_type(ty: &syn::Type) -> bool {
    if let syn::Type::Path(type_path) = ty {
        if type_path.qself.is_some() {
            return false;
        }
        return type_path
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "Vec");
    }
    false
}

impl ParamOptions {
    /// `Vec` arguments are moved into the user's method; everything else is
    /// passed by reference (see `derive_invoke_args`).
    fn should_pass_owned(&self) -> bool {
        self.resolved_type.as_ref().is_some_and(is_vec_type)
    }
}

fn validate_first_argument_is_agent_context(input_fn: &ItemFn) -> Result<(), Error> {
    let expected_first_arg = quote!
{ &dyn AgentContext };
    let error_msg = "The first argument must be `&dyn AgentContext`";
    if let Some(FnArg::Typed(first_arg)) = input_fn.sig.inputs.first() {
        // Token-level, textual comparison: the parameter must be spelled
        // exactly `&dyn AgentContext`; a differently written path to the same
        // trait would not match.
        if first_arg.ty.to_token_stream().to_string() != expected_first_arg.to_string() {
            // Span the error on the offending type so rustc points at it.
            return Err(Error::custom(error_msg).with_span(&first_arg.ty));
        }
    } else {
        // No typed first argument at all: attach the error to the whole
        // signature instead.
        return Err(Error::custom(error_msg).with_span(&input_fn.sig));
    }
    Ok(())
}

================================================
FILE: swiftide-macros/src/tool/mod.rs
================================================
#![allow(clippy::used_underscore_binding)]
#![allow(clippy::needless_continue)]
use args::ToolArgs;
use darling::{Error, FromDeriveInput};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{DeriveInput, FnArg, ItemFn, Pat, PatType, parse_quote};

mod args;
mod tool_spec;
mod wrapped;

/// Implementation of the `#[tool]` attribute macro: wraps the annotated async
/// function in a generated struct and implements
/// `swiftide::chat_completion::Tool` for it.
#[allow(clippy::too_many_lines)]
pub(crate) fn tool_attribute_impl(input_args: &TokenStream, input: &ItemFn) -> TokenStream {
    // Parse the attribute arguments (description, params, ...); on failure,
    // emit the accumulated darling errors as compile errors.
    let tool_args = match ToolArgs::try_from_attribute_input(input, input_args.clone()) {
        Ok(args) => args,
        Err(e) => return e.write_errors(),
    };
    let fn_name = &input.sig.ident;
    let args_struct = tool_args.args_struct();
    let args_struct_ident = tool_args.args_struct_ident();
    // Build the expressions used to forward the deserialized args to the
    // user's function, skipping the first parameter (the agent context).
    let arg_names = input
        .sig
        .inputs
        .iter()
        .skip(1)
        .filter_map(|arg| {
            if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
                if let Pat::Ident(ident) = &**pat {
                    // If the argument is a reference, we need to reference the quote as well
                    if let syn::Type::Reference(_) = &**ty {
                        Some(quote! { &args.#ident })
                    } else {
                        Some(quote! { args.#ident })
                    }
                } else {
                    None
                }
            } else {
                None
            }
        })
        .collect::<Vec<_>>();
    let tool_name = tool_args.tool_name();
    let tool_struct = wrapped::struct_name(input);
    let wrapped_fn = wrapped::wrap_tool_fn(input);
    let tool_spec = tool_spec::tool_spec(&tool_args);
    // Parameterless tools are invoked directly; tools with parameters first
    // deserialize the JSON arguments into the generated args struct.
    let invoke_body = if arg_names.is_empty() {
        quote! { return self.#fn_name(agent_context).await; }
    } else {
        quote!
{ let Some(args) = tool_call.args() else { return Err(::swiftide::chat_completion::errors::ToolError::MissingArguments(format!("No arguments provided for {}", #tool_name).into())) }; let args: #args_struct_ident = ::swiftide::reexports::serde_json::from_str(&args)?; return self.#fn_name(agent_context, #(#arg_names),*).await; } }; let boxed_from = boxed_from(&tool_struct, &parse_quote!()); quote! { #args_struct #wrapped_fn #[::swiftide::reexports::async_trait::async_trait] impl ::swiftide::chat_completion::Tool for #tool_struct { async fn invoke(&self, agent_context: &dyn ::swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall) -> ::std::result::Result<::swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError> { #invoke_body } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { #tool_name.into() } fn tool_spec(&self) -> ::swiftide::chat_completion::ToolSpec { #tool_spec } } #boxed_from } } #[allow(clippy::needless_continue)] #[derive(FromDeriveInput)] #[darling(attributes(tool), supports(struct_any), and_then = ToolDerive::update_defaults, forward_attrs(allow, doc, cfg))] struct ToolDerive { ident: syn::Ident, #[allow(dead_code)] attrs: Vec<syn::Attribute>, #[darling(flatten)] tool: ToolArgs, } impl ToolDerive { pub fn update_defaults(mut self) -> Result<Self, Error> { self.tool.with_name_from_ident(&self.ident); self.tool.infer_param_types()?; Ok(self) } } pub(crate) fn tool_derive_impl(input: &DeriveInput) -> syn::Result<TokenStream> { let parsed: ToolDerive = ToolDerive::from_derive_input(input)?; let struct_ident = &parsed.ident; let expected_fn_name = parsed.tool.fn_name(); let expected_fn_ident = syn::Ident::new(expected_fn_name, struct_ident.span()); let invoke_tool_args = parsed.tool.derive_invoke_args(); let args_struct_ident = parsed.tool.args_struct_ident(); let args_struct = parsed.tool.args_struct(); let invoke_body = if invoke_tool_args.is_empty() { quote! 
{ return self.#expected_fn_ident(agent_context).await } } else { quote! { let Some(args) = tool_call.args() else { return Err(::swiftide::chat_completion::errors::ToolError::MissingArguments(format!("No arguments provided for {}", #expected_fn_name).into())) }; let args: #args_struct_ident = ::swiftide::reexports::serde_json::from_str(&args)?; return self.#expected_fn_ident(agent_context, #(#invoke_tool_args),*).await; } }; let tool_spec = tool_spec::tool_spec(&parsed.tool); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); // Arg should be, if empty None, else Some(&args) let boxed_from = boxed_from(struct_ident, &input.generics); Ok(quote! { #args_struct #[async_trait::async_trait] impl #impl_generics swiftide::chat_completion::Tool for #struct_ident #ty_generics #where_clause { async fn invoke(&self, agent_context: &dyn swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall) -> std::result::Result<swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError> { #invoke_body } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { #expected_fn_name.into() } fn tool_spec(&self) -> swiftide::chat_completion::ToolSpec { #tool_spec } } #boxed_from }) } fn boxed_from(struct_ident: &syn::Ident, generics: &syn::Generics) -> TokenStream { if !generics.params.is_empty() { return quote!(); } let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let lt_ident = if let Some(other_lifetime) = generics.lifetimes().next() { let lifetime = &other_lifetime.lifetime; quote!(+ #lifetime) } else { quote!() }; quote! 
{ impl #impl_generics From<#struct_ident #ty_generics> for Box<dyn ::swiftide::chat_completion::Tool #lt_ident> #where_clause { fn from(val: #struct_ident) -> Self { Box::new(val) as Box<dyn ::swiftide::chat_completion::Tool> } } } } #[cfg(test)] mod tests { use super::*; use quote::quote; use syn::{ItemFn, parse_quote}; #[test] fn test_snapshot_single_arg() { let args = quote! { description = "Hello world tool", param( name = "code_query", description = "my param description" ) }; let input: ItemFn = parse_quote! { pub async fn search_code(context: &dyn AgentContext, code_query: &str) -> Result<ToolOutput, ToolError> { return Ok("hello".into()) } }; let output = tool_attribute_impl(&args, &input); insta::assert_snapshot!(crate::test_utils::pretty_macro_output(&output)); } #[test] fn test_snapshot_single_arg_option() { let args = quote! { description = "Hello world tool", param( name = "code_query", description = "my param description" ) }; let input: ItemFn = parse_quote! { pub async fn search_code(context: &dyn AgentContext, code_query: &Option<String>) -> Result<ToolOutput, ToolError> { return Ok("hello".into()) } }; let output = tool_attribute_impl(&args, &input); insta::assert_snapshot!(crate::test_utils::pretty_macro_output(&output)); } #[test] fn test_snapshot_multiple_args() { let args = quote! { description = "Hello world tool", param( name = "code_query", description = "my param description" ), param( name = "other", description = "my param description" ) }; let input: ItemFn = parse_quote! { pub async fn search_code(context: &dyn AgentContext, code_query: &str, other: &str) -> Result<ToolOutput> { return Ok("hello".into()) } }; let output = tool_attribute_impl(&args, &input); insta::assert_snapshot!(crate::test_utils::pretty_macro_output(&output)); } #[test] fn test_snapshot_derive() { let input: DeriveInput = parse_quote! 
{ #[tool(description="Hello derive")] pub struct HelloDerive { my_thing: String } }; let output = tool_derive_impl(&input).unwrap(); insta::assert_snapshot!(crate::test_utils::pretty_macro_output(&output)); } #[test] fn test_snapshot_derive_with_args() { let input: DeriveInput = parse_quote! { #[tool(description="Hello derive", param(name="test", description="test param"))] pub struct HelloDerive { my_thing: String } }; let output = tool_derive_impl(&input).unwrap(); insta::assert_snapshot!(crate::test_utils::pretty_macro_output(&output)); } #[test] fn test_snapshot_derive_with_option() { let input: DeriveInput = parse_quote! { #[tool(description="Hello derive", param(name="test", description="test param", required = false))] pub struct HelloDerive { my_thing: String } }; let output = tool_derive_impl(&input).unwrap(); insta::assert_snapshot!(crate::test_utils::pretty_macro_output(&output)); } #[test] fn test_snapshot_derive_with_lifetime() { let input: DeriveInput = parse_quote! { #[tool(description="Hello derive", param(name="test", description="test param"))] pub struct HelloDerive<'a> { my_thing: &'a str, } }; let output = tool_derive_impl(&input).unwrap(); insta::assert_snapshot!(crate::test_utils::pretty_macro_output(&output)); } #[test] fn test_snapshot_derive_with_generics() { let input: DeriveInput = parse_quote! 
{ #[tool(description="Hello derive", param(name="test", description="test param"))] pub struct HelloDerive<S: Send + Sync + Clone> { my_thing: S, } }; let output = tool_derive_impl(&input).unwrap(); insta::assert_snapshot!(crate::test_utils::pretty_macro_output(&output)); } } ================================================ FILE: swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__simple_tool.snap ================================================ --- source: swiftide-macros/src/tool/mod.rs expression: "crate::test_utils::pretty_macro_output(&output)" --- mod hidden { pub use swiftide_agents::{Tool, AgentContext}; pub use anyhow::{bail, Result}; pub use swiftide_core::chat_completion::{JsonSpec, ToolOutput}; pub use async_trait::async_trait; } #[derive(serde::Serialize, serde::Deserialize)] struct SearchCodeArgs<'a> { pub code_query: &'a str, } #[derive(Clone)] struct SearchCode {} pub fn search_code() -> SearchCode { SearchCode {} } impl SearchCode { pub async fn search_code( &self, context: &dyn AgentContext, code_query: &str, ) -> Result<ToolOutput> { return Ok("hello".into()); } } #[hidden::async_trait] impl hidden::Tool for SearchCode { async fn invoke( &self, agent_context: &dyn hidden::AgentContext, raw_args: Option<&str>, ) -> hidden::Result<hidden::ToolOutput> { let Some(args) = raw_args else { hidden::bail!("No arguments provided for {}", "search_code") }; let args: SearchCodeArgs = serde_json::from_str(&args)?; return self.search_code(agent_context, args.code_query).await; } fn name(&self) -> &'static str { "search_code" } fn json_spec(&self) -> hidden::JsonSpec { "{\n \"description\": \"Hello world tool\",\n \"name\": \"search_code\",\n \"parameters\": {\n \"my param\": {\n \"description\": \"my param description\",\n \"type\": \"string\"\n }\n }\n}" } } ================================================ FILE: swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__snapshot_derive.snap ================================================ 
--- source: swiftide-macros/src/tool/mod.rs expression: "crate::test_utils::pretty_macro_output(&output)" --- #[async_trait::async_trait] impl swiftide::chat_completion::Tool for HelloDerive { async fn invoke( &self, agent_context: &dyn swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall, ) -> std::result::Result< swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError, > { return self.hello_derive(agent_context).await; } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { "hello_derive".into() } fn tool_spec(&self) -> swiftide::chat_completion::ToolSpec { swiftide::chat_completion::ToolSpec::builder() .name("hello_derive") .description("Hello derive") .build() .unwrap() } } impl From<HelloDerive> for Box<dyn ::swiftide::chat_completion::Tool> { fn from(val: HelloDerive) -> Self { Box::new(val) as Box<dyn ::swiftide::chat_completion::Tool> } } ================================================ FILE: swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__snapshot_derive_with_args.snap ================================================ --- source: swiftide-macros/src/tool/mod.rs expression: "crate::test_utils::pretty_macro_output(&output)" --- #[derive( ::swiftide::reexports::serde::Serialize, ::swiftide::reexports::serde::Deserialize, ::swiftide::reexports::schemars::JsonSchema, Debug )] #[schemars(crate = "::swiftide::reexports::schemars", deny_unknown_fields)] pub struct HelloDeriveArgs { pub test: String, } #[async_trait::async_trait] impl swiftide::chat_completion::Tool for HelloDerive { async fn invoke( &self, agent_context: &dyn swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall, ) -> std::result::Result< swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError, > { let Some(args) = tool_call.args() else { return Err( ::swiftide::chat_completion::errors::ToolError::MissingArguments( format!("No arguments provided for {}", 
"hello_derive").into(), ), ) }; let args: HelloDeriveArgs = ::swiftide::reexports::serde_json::from_str(&args)?; return self.hello_derive(agent_context, &args.test).await; } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { "hello_derive".into() } fn tool_spec(&self) -> swiftide::chat_completion::ToolSpec { swiftide::chat_completion::ToolSpec::builder() .name("hello_derive") .description("Hello derive") .parameters_schema( ::swiftide::reexports::schemars::schema_for!(HelloDeriveArgs), ) .build() .unwrap() } } impl From<HelloDerive> for Box<dyn ::swiftide::chat_completion::Tool> { fn from(val: HelloDerive) -> Self { Box::new(val) as Box<dyn ::swiftide::chat_completion::Tool> } } ================================================ FILE: swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__snapshot_derive_with_generics.snap ================================================ --- source: swiftide-macros/src/tool/mod.rs expression: "crate::test_utils::pretty_macro_output(&output)" --- #[derive( ::swiftide::reexports::serde::Serialize, ::swiftide::reexports::serde::Deserialize, ::swiftide::reexports::schemars::JsonSchema, Debug )] #[schemars(crate = "::swiftide::reexports::schemars", deny_unknown_fields)] pub struct HelloDeriveArgs { pub test: String, } #[async_trait::async_trait] impl<S: Send + Sync + Clone> swiftide::chat_completion::Tool for HelloDerive<S> { async fn invoke( &self, agent_context: &dyn swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall, ) -> std::result::Result< swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError, > { let Some(args) = tool_call.args() else { return Err( ::swiftide::chat_completion::errors::ToolError::MissingArguments( format!("No arguments provided for {}", "hello_derive").into(), ), ) }; let args: HelloDeriveArgs = ::swiftide::reexports::serde_json::from_str(&args)?; return self.hello_derive(agent_context, &args.test).await; } fn name<'TOOL>(&'TOOL self) 
-> std::borrow::Cow<'TOOL, str> { "hello_derive".into() } fn tool_spec(&self) -> swiftide::chat_completion::ToolSpec { swiftide::chat_completion::ToolSpec::builder() .name("hello_derive") .description("Hello derive") .parameters_schema( ::swiftide::reexports::schemars::schema_for!(HelloDeriveArgs), ) .build() .unwrap() } } ================================================ FILE: swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__snapshot_derive_with_lifetime.snap ================================================ --- source: swiftide-macros/src/tool/mod.rs expression: "crate::test_utils::pretty_macro_output(&output)" --- #[derive( ::swiftide::reexports::serde::Serialize, ::swiftide::reexports::serde::Deserialize, ::swiftide::reexports::schemars::JsonSchema, Debug )] #[schemars(crate = "::swiftide::reexports::schemars", deny_unknown_fields)] pub struct HelloDeriveArgs { pub test: String, } #[async_trait::async_trait] impl<'a> swiftide::chat_completion::Tool for HelloDerive<'a> { async fn invoke( &self, agent_context: &dyn swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall, ) -> std::result::Result< swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError, > { let Some(args) = tool_call.args() else { return Err( ::swiftide::chat_completion::errors::ToolError::MissingArguments( format!("No arguments provided for {}", "hello_derive").into(), ), ) }; let args: HelloDeriveArgs = ::swiftide::reexports::serde_json::from_str(&args)?; return self.hello_derive(agent_context, &args.test).await; } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { "hello_derive".into() } fn tool_spec(&self) -> swiftide::chat_completion::ToolSpec { swiftide::chat_completion::ToolSpec::builder() .name("hello_derive") .description("Hello derive") .parameters_schema( ::swiftide::reexports::schemars::schema_for!(HelloDeriveArgs), ) .build() .unwrap() } } ================================================ FILE: 
swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__snapshot_derive_with_option.snap ================================================ --- source: swiftide-macros/src/tool/mod.rs expression: "crate::test_utils::pretty_macro_output(&output)" --- #[derive( ::swiftide::reexports::serde::Serialize, ::swiftide::reexports::serde::Deserialize, ::swiftide::reexports::schemars::JsonSchema, Debug )] #[schemars(crate = "::swiftide::reexports::schemars", deny_unknown_fields)] pub struct HelloDeriveArgs { pub test: Option<String>, } #[async_trait::async_trait] impl swiftide::chat_completion::Tool for HelloDerive { async fn invoke( &self, agent_context: &dyn swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall, ) -> std::result::Result< swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError, > { let Some(args) = tool_call.args() else { return Err( ::swiftide::chat_completion::errors::ToolError::MissingArguments( format!("No arguments provided for {}", "hello_derive").into(), ), ) }; let args: HelloDeriveArgs = ::swiftide::reexports::serde_json::from_str(&args)?; return self.hello_derive(agent_context, &args.test).await; } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { "hello_derive".into() } fn tool_spec(&self) -> swiftide::chat_completion::ToolSpec { swiftide::chat_completion::ToolSpec::builder() .name("hello_derive") .description("Hello derive") .parameters_schema( ::swiftide::reexports::schemars::schema_for!(HelloDeriveArgs), ) .build() .unwrap() } } impl From<HelloDerive> for Box<dyn ::swiftide::chat_completion::Tool> { fn from(val: HelloDerive) -> Self { Box::new(val) as Box<dyn ::swiftide::chat_completion::Tool> } } ================================================ FILE: swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__snapshot_multiple_args.snap ================================================ --- source: swiftide-macros/src/tool/mod.rs expression: 
"crate::test_utils::pretty_macro_output(&output)" --- #[derive( ::swiftide::reexports::serde::Serialize, ::swiftide::reexports::serde::Deserialize, ::swiftide::reexports::schemars::JsonSchema, Debug )] #[schemars(crate = "::swiftide::reexports::schemars", deny_unknown_fields)] pub struct SearchCodeArgs { pub code_query: String, pub other: String, } #[derive(Clone, Default)] pub struct SearchCode {} pub fn search_code() -> Box<dyn ::swiftide::chat_completion::Tool> { Box::new(SearchCode {}) as Box<dyn ::swiftide::chat_completion::Tool> } impl SearchCode { pub async fn search_code( &self, context: &dyn AgentContext, code_query: &str, other: &str, ) -> Result<ToolOutput> { return Ok("hello".into()); } } #[::swiftide::reexports::async_trait::async_trait] impl ::swiftide::chat_completion::Tool for SearchCode { async fn invoke( &self, agent_context: &dyn ::swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall, ) -> ::std::result::Result< ::swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError, > { let Some(args) = tool_call.args() else { return Err( ::swiftide::chat_completion::errors::ToolError::MissingArguments( format!("No arguments provided for {}", "search_code").into(), ), ) }; let args: SearchCodeArgs = ::swiftide::reexports::serde_json::from_str(&args)?; return self.search_code(agent_context, &args.code_query, &args.other).await; } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { "search_code".into() } fn tool_spec(&self) -> ::swiftide::chat_completion::ToolSpec { swiftide::chat_completion::ToolSpec::builder() .name("search_code") .description("Hello world tool") .parameters_schema( ::swiftide::reexports::schemars::schema_for!(SearchCodeArgs), ) .build() .unwrap() } } impl From<SearchCode> for Box<dyn ::swiftide::chat_completion::Tool> { fn from(val: SearchCode) -> Self { Box::new(val) as Box<dyn ::swiftide::chat_completion::Tool> } } ================================================ FILE: 
swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__snapshot_single_arg.snap ================================================ --- source: swiftide-macros/src/tool/mod.rs expression: "crate::test_utils::pretty_macro_output(&output)" --- #[derive( ::swiftide::reexports::serde::Serialize, ::swiftide::reexports::serde::Deserialize, ::swiftide::reexports::schemars::JsonSchema, Debug )] #[schemars(crate = "::swiftide::reexports::schemars", deny_unknown_fields)] pub struct SearchCodeArgs { pub code_query: String, } #[derive(Clone, Default)] pub struct SearchCode {} pub fn search_code() -> Box<dyn ::swiftide::chat_completion::Tool> { Box::new(SearchCode {}) as Box<dyn ::swiftide::chat_completion::Tool> } impl SearchCode { pub async fn search_code( &self, context: &dyn AgentContext, code_query: &str, ) -> Result<ToolOutput, ToolError> { return Ok("hello".into()); } } #[::swiftide::reexports::async_trait::async_trait] impl ::swiftide::chat_completion::Tool for SearchCode { async fn invoke( &self, agent_context: &dyn ::swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall, ) -> ::std::result::Result< ::swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError, > { let Some(args) = tool_call.args() else { return Err( ::swiftide::chat_completion::errors::ToolError::MissingArguments( format!("No arguments provided for {}", "search_code").into(), ), ) }; let args: SearchCodeArgs = ::swiftide::reexports::serde_json::from_str(&args)?; return self.search_code(agent_context, &args.code_query).await; } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { "search_code".into() } fn tool_spec(&self) -> ::swiftide::chat_completion::ToolSpec { swiftide::chat_completion::ToolSpec::builder() .name("search_code") .description("Hello world tool") .parameters_schema( ::swiftide::reexports::schemars::schema_for!(SearchCodeArgs), ) .build() .unwrap() } } impl From<SearchCode> for Box<dyn ::swiftide::chat_completion::Tool> 
{ fn from(val: SearchCode) -> Self { Box::new(val) as Box<dyn ::swiftide::chat_completion::Tool> } } ================================================ FILE: swiftide-macros/src/tool/snapshots/swiftide_macros__tool__tests__snapshot_single_arg_option.snap ================================================ --- source: swiftide-macros/src/tool/mod.rs expression: "crate::test_utils::pretty_macro_output(&output)" --- #[derive( ::swiftide::reexports::serde::Serialize, ::swiftide::reexports::serde::Deserialize, ::swiftide::reexports::schemars::JsonSchema, Debug )] #[schemars(crate = "::swiftide::reexports::schemars", deny_unknown_fields)] pub struct SearchCodeArgs { pub code_query: Option<String>, } #[derive(Clone, Default)] pub struct SearchCode {} pub fn search_code() -> Box<dyn ::swiftide::chat_completion::Tool> { Box::new(SearchCode {}) as Box<dyn ::swiftide::chat_completion::Tool> } impl SearchCode { pub async fn search_code( &self, context: &dyn AgentContext, code_query: &Option<String>, ) -> Result<ToolOutput, ToolError> { return Ok("hello".into()); } } #[::swiftide::reexports::async_trait::async_trait] impl ::swiftide::chat_completion::Tool for SearchCode { async fn invoke( &self, agent_context: &dyn ::swiftide::traits::AgentContext, tool_call: &swiftide::chat_completion::ToolCall, ) -> ::std::result::Result< ::swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError, > { let Some(args) = tool_call.args() else { return Err( ::swiftide::chat_completion::errors::ToolError::MissingArguments( format!("No arguments provided for {}", "search_code").into(), ), ) }; let args: SearchCodeArgs = ::swiftide::reexports::serde_json::from_str(&args)?; return self.search_code(agent_context, &args.code_query).await; } fn name<'TOOL>(&'TOOL self) -> std::borrow::Cow<'TOOL, str> { "search_code".into() } fn tool_spec(&self) -> ::swiftide::chat_completion::ToolSpec { swiftide::chat_completion::ToolSpec::builder() .name("search_code") .description("Hello world 
tool") .parameters_schema( ::swiftide::reexports::schemars::schema_for!(SearchCodeArgs), ) .build() .unwrap() } } impl From<SearchCode> for Box<dyn ::swiftide::chat_completion::Tool> { fn from(val: SearchCode) -> Self { Box::new(val) as Box<dyn ::swiftide::chat_completion::Tool> } } ================================================ FILE: swiftide-macros/src/tool/tool_spec.rs ================================================ use proc_macro2::TokenStream; use quote::quote; use super::args::{Description, ToolArgs}; pub fn tool_spec(args: &ToolArgs) -> TokenStream { let tool_name = args.tool_name(); let description = match &args.tool_description() { Description::Literal(description) => quote! { #description }, Description::Path(path) => quote! { #path }, }; let builder = quote! { swiftide::chat_completion::ToolSpec::builder() .name(#tool_name) .description(#description) }; if args.tool_params().is_empty() { quote! { #builder.build().unwrap() } } else { let args_struct_ident = args.args_struct_ident(); quote! { #builder .parameters_schema(::swiftide::reexports::schemars::schema_for!(#args_struct_ident)) .build() .unwrap() } } } ================================================ FILE: swiftide-macros/src/tool/wrapped.rs ================================================ use proc_macro2::TokenStream; use quote::quote; use syn::{Ident, ItemFn}; pub(crate) fn struct_name(input: &ItemFn) -> Ident { let struct_name_str = input .sig .ident .to_string() .split('_') // Split by underscores .map(|s| { let mut chars = s.chars(); chars .next() .map(|c| c.to_ascii_uppercase()) .into_iter() .collect::<String>() + chars.as_str() }) .collect::<String>(); Ident::new(&struct_name_str, input.sig.ident.span()) } pub(crate) fn wrap_tool_fn(input: &ItemFn) -> TokenStream { let fn_name = &input.sig.ident; let fn_args = &input.sig.inputs; let fn_body = &input.block; let fn_output = &input.sig.output; let struct_name = struct_name(input); let fn_args = fn_args.iter(); quote! 
{ #[derive(Clone, Default)] pub struct #struct_name {} pub fn #fn_name() -> Box<dyn ::swiftide::chat_completion::Tool> { Box::new(#struct_name {}) as Box<dyn ::swiftide::chat_completion::Tool> } impl #struct_name { pub async fn #fn_name(&self, #(#fn_args),*) #fn_output #fn_body } } } #[cfg(test)] mod tests { use crate::assert_ts_eq; use super::*; use quote::quote; use syn::{ItemFn, parse_quote}; #[test] fn test_wrap_tool_fn() { let input: ItemFn = parse_quote! { pub async fn search_code(context: &dyn swiftide::traits::AgentContext, code_query: &str) -> std::result::Result<swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError> { return Ok("hello".into()) } }; let output = wrap_tool_fn(&input); let expected = quote! { #[derive(Clone, Default)] pub struct SearchCode {} pub fn search_code() -> Box<dyn ::swiftide::chat_completion::Tool> { Box::new(SearchCode {}) as Box<dyn ::swiftide::chat_completion::Tool> } impl SearchCode { pub async fn search_code(&self, context: &dyn swiftide::traits::AgentContext, code_query: &str) -> std::result::Result<swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError> { return Ok("hello".into()) } } }; assert_ts_eq!(&output, &expected); } #[test] fn test_wrap_multiple_args() { let input: ItemFn = parse_quote! { pub async fn search_code(context: &dyn swiftide::traits::AgentContext, code_query: &str, other_arg: &str) -> std::result::Result<swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError> { return Ok("hello".into()) } }; let output = wrap_tool_fn(&input); let expected = quote! 
{ #[derive(Clone, Default)] pub struct SearchCode {} pub fn search_code() -> Box<dyn ::swiftide::chat_completion::Tool> { Box::new(SearchCode {}) as Box<dyn ::swiftide::chat_completion::Tool> } impl SearchCode { pub async fn search_code(&self, context: &dyn swiftide::traits::AgentContext, code_query: &str, other_arg: &str) -> std::result::Result<swiftide::chat_completion::ToolOutput, ::swiftide::chat_completion::errors::ToolError> { return Ok("hello".into()) } } }; assert_ts_eq!(&output, &expected); } } ================================================ FILE: swiftide-macros/tests/tool/tool_derive_missing_description.rs ================================================ use swiftide::chat_completion::{errors::ToolError, ToolOutput}; use swiftide::traits::AgentContext; use swiftide_macros::Tool; #[derive(Clone, Tool)] struct MyToolNoArgs { test: String, } impl MyToolNoArgs { async fn my_tool_no_args( &self, _agent_context: &dyn AgentContext, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } fn main() {} ================================================ FILE: swiftide-macros/tests/tool/tool_derive_missing_description.stderr ================================================ error: Missing field `description` --> tests/tool/tool_derive_missing_description.rs:5:17 | 5 | #[derive(Clone, Tool)] | ^^^^ | = note: this error originates in the derive macro `Tool` (in Nightly builds, run with -Z macro-backtrace for more info) ================================================ FILE: swiftide-macros/tests/tool/tool_derive_pass.rs ================================================ #![allow(unused_variables)] use swiftide::chat_completion::{errors::ToolError, ToolOutput}; use swiftide::traits::AgentContext; use swiftide_macros::Tool; #[derive(Clone, Tool)] #[tool( description = "Hello tool", param(name = "test", description = "My param") )] struct MyTool { test: String, } impl MyTool { async fn my_tool( &self, agent_context: &dyn AgentContext, test: &str, ) -> 
Result<ToolOutput, ToolError> { Ok(format!("Hello {test}").into()) } } #[derive(Clone, Tool)] #[tool( description = "Hello tool", param(name = "test", description = "My param"), param(name = "other", description = "My other param") )] struct MyToolMultiParams {} impl MyToolMultiParams { async fn my_tool_multi_params( &self, agent_context: &dyn AgentContext, test: &str, other: &str, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {test} {other}").into()) } } #[derive(Clone, Tool)] #[tool(description = "Hello tool")] struct MyToolNoArgs { test: String, } impl MyToolNoArgs { async fn my_tool_no_args( &self, agent_context: &dyn AgentContext, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } #[derive(Clone, Tool)] #[tool(description = "Hello tool")] struct MyToolLifetime<'a> { test: &'a str, } impl MyToolLifetime<'_> { async fn my_tool_lifetime( &self, agent_context: &dyn AgentContext, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } const DESCRIPTION: &str = "Hello tool"; #[derive(Clone, Tool)] #[tool(description = DESCRIPTION)] struct MyToolConst<'a> { test: &'a str, } impl MyToolConst<'_> { async fn my_tool_const( &self, agent_context: &dyn AgentContext, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } #[derive(Clone, Tool)] #[tool(description = DESCRIPTION, param(name = "test", description = "My param", json_type = "number") )] struct MyToolNumber; impl MyToolNumber { async fn my_tool_number( &self, agent_context: &dyn AgentContext, test: &usize, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } #[derive(Clone, Tool)] #[tool(description = DESCRIPTION, param(name = "test", description = "My param", rust_type = "usize") )] struct MyToolNumber2; impl MyToolNumber2 { async fn my_tool_number_2( &self, agent_context: &dyn AgentContext, test: &usize, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } #[derive(Clone, Tool)] #[tool(description 
= DESCRIPTION, name = "my_very_renamed_tool", fn_name = "my_very_renamed_tool", param(name = "test", description = "My param", rust_type = "usize") )] struct MyRenamedTool; impl MyRenamedTool { async fn my_very_renamed_tool( &self, agent_context: &dyn AgentContext, test: &usize, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } #[derive(Clone, Tool)] #[tool(description = DESCRIPTION, param(name = "test", description = "My param", required = false) )] struct MyOptionalTool; impl MyOptionalTool { async fn my_optional_tool( &self, agent_context: &dyn AgentContext, test: &Option<String>, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } #[derive(Clone, Tool)] #[tool(description = DESCRIPTION, param(name = "test", description = "My param", rust_type = "Option<usize>") )] struct MyOptionalTool2; impl MyOptionalTool2 { async fn my_optional_tool_2( &self, agent_context: &dyn AgentContext, test: &Option<usize>, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } #[derive(Clone, Tool)] #[tool(description = DESCRIPTION, param(name = "test", description = "My param") )] struct MyGenericTool<S: Send + Sync + Clone> { thing: S, } impl<S: Send + Sync + Clone> MyGenericTool<S> { async fn my_generic_tool( &self, agent_context: &dyn AgentContext, test: &str, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello world").into()) } } fn main() {} ================================================ FILE: swiftide-macros/tests/tool/tool_derive_vec_argument_pass.rs ================================================ #![allow(unused_variables)] use swiftide::chat_completion::{errors::ToolError, ToolOutput}; use swiftide::traits::AgentContext; use swiftide_macros::Tool; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, swiftide::reexports::schemars::JsonSchema)] struct CustomType { value: String, } #[derive(Clone, Tool)] #[tool( description = "Tool that takes a Vec<CustomType>", param(name = "items", description 
= "items", rust_type = "Vec<CustomType>") )] struct VecTool; impl VecTool { async fn vec_tool( &self, agent_context: &dyn AgentContext, items: Vec<CustomType>, ) -> Result<ToolOutput, ToolError> { Ok(format!("Received {} items", items.len()).into()) } } #[derive(Clone, Tool)] #[tool( description = "Tool that takes nested Vec<CustomType>", param(name = "items", description = "nested items", rust_type = "Vec<Vec<CustomType>>") )] struct NestedVecTool; impl NestedVecTool { async fn nested_vec_tool( &self, agent_context: &dyn AgentContext, items: Vec<Vec<CustomType>>, ) -> Result<ToolOutput, ToolError> { Ok(format!("Received {} groups", items.len()).into()) } } fn main() {} ================================================ FILE: swiftide-macros/tests/tool/tool_missing_arg_fail.rs ================================================ use swiftide::chat_completion::errors::ToolError; use swiftide::chat_completion::ToolOutput; use swiftide::traits::AgentContext; #[swiftide_macros::tool( description = "My first tool", param(name = "msg", description = "A message for testing") )] async fn basic_tool( _agent_context: &dyn AgentContext, msg: &str, other: &str, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg}").into()) } const READ_FILE: &str = "Read a file"; #[swiftide_macros::tool( description = READ_FILE, param(name = "number", description = "Number to guess") )] async fn guess_a_number( _context: &dyn AgentContext, number: usize, ) -> Result<ToolOutput, ToolError> { let actual_number = 42; if number == actual_number { Ok("You guessed it!".into()) } else { Ok("Try again!".into()) } } fn main() {} ================================================ FILE: swiftide-macros/tests/tool/tool_missing_arg_fail.stderr ================================================ error: The following parameters are missing from the spec: ["other"] --> tests/tool/tool_missing_arg_fail.rs:5:1 | 5 | / #[swiftide_macros::tool( 6 | | description = "My first tool", 7 | | param(name = "msg", 
description = "A message for testing") 8 | | )] | |__^ | = note: this error originates in the attribute macro `swiftide_macros::tool` (in Nightly builds, run with -Z macro-backtrace for more info) ================================================ FILE: swiftide-macros/tests/tool/tool_missing_parameter_fail.rs ================================================ #[swiftide_macros::tool( description = "My first tool", param(name = "Message", description = "A message for testing") )] async fn basic_tool(_agent_context: &dyn AgentContext, msg: &str) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg}").into()) } fn main() {} ================================================ FILE: swiftide-macros/tests/tool/tool_missing_parameter_fail.stderr ================================================ error: The following parameters are missing from the function signature: ["Message"] --> tests/tool/tool_missing_parameter_fail.rs:1:1 | 1 | / #[swiftide_macros::tool( 2 | | description = "My first tool", 3 | | param(name = "Message", description = "A message for testing") 4 | | )] | |__^ | = note: this error originates in the attribute macro `swiftide_macros::tool` (in Nightly builds, run with -Z macro-backtrace for more info) error: The following parameters are missing from the spec: ["msg"] --> tests/tool/tool_missing_parameter_fail.rs:1:1 | 1 | / #[swiftide_macros::tool( 2 | | description = "My first tool", 3 | | param(name = "Message", description = "A message for testing") 4 | | )] | |__^ | = note: this error originates in the attribute macro `swiftide_macros::tool` (in Nightly builds, run with -Z macro-backtrace for more info) ================================================ FILE: swiftide-macros/tests/tool/tool_multiple_arguments_pass.rs ================================================ use swiftide::chat_completion::{errors::ToolError, ToolOutput}; use swiftide::traits::AgentContext; #[swiftide_macros::tool( description = "My first tool", param(name = "msg", description = "A 
message for testing"), param(name = "other", description = "A message for testing") )] async fn basic_tool( _agent_context: &dyn AgentContext, msg: &str, other: &str, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg}").into()) } fn main() {} ================================================ FILE: swiftide-macros/tests/tool/tool_no_argument_pass.rs ================================================ use swiftide::chat_completion::{errors::ToolError, ToolOutput}; use swiftide::traits::AgentContext; #[swiftide_macros::tool(description = "My first tool")] async fn basic_tool(_agent_context: &dyn AgentContext) -> Result<ToolOutput, ToolError> { Ok(format!("Hello tool").into()) } fn main() {} ================================================ FILE: swiftide-macros/tests/tool/tool_object_argument_pass.rs ================================================ use std::collections::BTreeMap; use serde_json::Value; use swiftide::chat_completion::{errors::ToolError, ToolOutput}; use swiftide::traits::AgentContext; #[swiftide_macros::tool( description = "Tool that accepts object payloads", param(name = "payload", description = "Arbitrary JSON object") )] async fn object_tool( _ctx: &dyn AgentContext, payload: BTreeMap<String, Value>, ) -> Result<ToolOutput, ToolError> { Ok(ToolOutput::text(format!("keys={}", payload.len()))) } fn main() {} ================================================ FILE: swiftide-macros/tests/tool/tool_single_argument_pass.rs ================================================ use swiftide::chat_completion::{errors::ToolError, ToolOutput}; use swiftide::traits::AgentContext; #[swiftide_macros::tool( description = "My first tool", param(name = "msg", description = "A message for testing") )] async fn basic_tool(_agent_context: &dyn AgentContext, msg: &str) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg}").into()) } #[swiftide_macros::tool( description = "My first num tool", param( name = "msg", description = "A message for testing", json_type = 
"number" ) )] async fn basic_tool_num( _agent_context: &dyn AgentContext, msg: i32, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg}").into()) } #[swiftide_macros::tool( description = "My first num tool", param(name = "msg", description = "A message for testing") )] async fn basic_tool_num_no_type( _agent_context: &dyn AgentContext, msg: i32, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg}").into()) } #[swiftide_macros::tool( description = "My first array tool", param( name = "msg", description = "A message for testing", json_type = "array" ) )] async fn basic_tool_vec( _agent_context: &dyn AgentContext, msg: Vec<String>, ) -> Result<ToolOutput, ToolError> { let msg = msg.join(", "); Ok(format!("Hello {msg}").into()) } #[swiftide_macros::tool( description = "My first bool tool", param( name = "msg", description = "A message for testing", json_type = "boolean" ) )] async fn basic_tool_bool( _agent_context: &dyn AgentContext, msg: bool, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg}").into()) } #[swiftide_macros::tool( description = "My first num slice tool", param( name = "msg", description = "A message for testing", json_type = "array" ) )] async fn basic_tool_num_slice( _agent_context: &dyn AgentContext, msg: &[i32], ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg:?}").into()) } #[swiftide_macros::tool( description = "My first num slice tool", param(name = "msg", description = "A message for testing") )] async fn basic_tool_num_optional( _agent_context: &dyn AgentContext, msg: Option<i32>, ) -> Result<ToolOutput, ToolError> { Ok(format!("Hello {msg:?}").into()) } fn main() {} ================================================ FILE: swiftide-macros/tests/tool.rs ================================================ #[rustversion::attr(nightly, ignore = "nightly has different output")] #[test] fn test_tool() { let t = trybuild::TestCases::new(); t.pass("tests/tool/tool_single_argument_pass.rs"); 
t.pass("tests/tool/tool_no_argument_pass.rs"); t.pass("tests/tool/tool_multiple_arguments_pass.rs"); t.pass("tests/tool/tool_object_argument_pass.rs"); t.compile_fail("tests/tool/tool_missing_arg_fail.rs"); t.compile_fail("tests/tool/tool_missing_parameter_fail.rs"); } #[rustversion::attr(nightly, ignore = "nightly has different output")] #[test] fn test_tool_derive() { let t = trybuild::TestCases::new(); t.pass("tests/tool/tool_derive_pass.rs"); t.pass("tests/tool/tool_derive_vec_argument_pass.rs"); t.compile_fail("tests/tool/tool_derive_missing_description.rs"); } ================================================ FILE: swiftide-query/Cargo.toml ================================================ cargo-features = ["edition2024"] [package] name = "swiftide-query" version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true [dependencies] anyhow = { workspace = true } async-trait = { workspace = true } derive_builder = { workspace = true } futures-util = { workspace = true } tokio = { workspace = true } num_cpus = { workspace = true } tracing = { workspace = true } indoc = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } tera = { workspace = true } # Internal swiftide-core = { path = "../swiftide-core", version = "0.32.1" } [dev-dependencies] swiftide-core = { path = "../swiftide-core", features = ["test-utils"] } insta = { workspace = true } [lints] workspace = true [package.metadata.docs.rs] all-features = true cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: swiftide-query/src/answers/mod.rs ================================================ //! 
Given a query, generate an answer mod simple; pub use simple::*; ================================================ FILE: swiftide-query/src/answers/simple.rs ================================================ //! Generate an answer based on the current query use std::sync::Arc; use swiftide_core::{ Answer, document::Document, indexing::SimplePrompt, prelude::*, prompt::Prompt, querying::{Query, states}, }; /// Generate an answer based on the current query /// /// For most general purposes, this transformer should provide a sensible default. It takes either /// a transformation that has already been applied to the documents (in `Query::current`), or the /// documents themselves, and will then feed them as context with the _original_ question to an llm /// to generate an answer. /// /// For the template context, the following variables are available: /// - **question**: The original question asked by the user /// - **original**: Alias for `question` /// - **current**: The current transformed query /// - **documents**: The documents to use as context /// /// Optionally, a custom document template can be provided to render the documents in a specific /// way. #[derive(Debug, Clone, Builder)] pub struct Simple { #[builder(setter(custom))] client: Arc<dyn SimplePrompt>, #[builder(default = "default_prompt()")] prompt_template: Prompt, #[builder(default, setter(into, strip_option))] document_template: Option<Prompt>, } impl Simple { pub fn builder() -> SimpleBuilder { SimpleBuilder::default() } /// Builds a new simple answer generator from a client that implements [`SimplePrompt`]. 
/// /// # Panics /// /// Panics if the build failed pub fn from_client(client: impl SimplePrompt + 'static) -> Simple { SimpleBuilder::default() .client(client) .to_owned() .build() .expect("Failed to build Simple") } } impl SimpleBuilder { pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self { self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>); self } } fn default_prompt() -> Prompt { indoc::indoc! {" Answer the following question based on the context provided: {{ question }} ## Constraints * Do not include any information that is not in the provided context. * If the question cannot be answered by the provided context, state that it cannot be answered. * Answer the question completely and format it as markdown. ## Context --- {{ documents }} --- "} .into() } #[async_trait] impl Answer for Simple { #[tracing::instrument(skip_all)] async fn answer(&self, query: Query<states::Retrieved>) -> Result<Query<states::Answered>> { let mut context = tera::Context::new(); context.insert("question", query.original()); context.insert("original", query.original()); context.insert("current", query.current()); // If there is a current transformation that is different from the original (transformed) // query, use those as documents (i.e. a summary) let documents = if !query.current().is_empty() && query .history() .iter() .rfind(|e| e.is_retrieval()) .is_some_and(|h| h.before() != query.current()) { query.current().to_string() } else if let Some(template) = &self.document_template { let mut rendered_documents = Vec::new(); for document in query.documents() { let rendered = template .clone() .with_context(tera::Context::from_serialize(document)?) 
.render()?; rendered_documents.push(rendered); } rendered_documents.join("\n---\n") } else { query .documents() .iter() .map(Document::content) .collect::<Vec<_>>() .join("\n---\n") }; context.insert("documents", &documents); let answer = self .client .prompt(self.prompt_template.clone().with_context(context)) .await?; Ok(query.answered(answer)) } } #[cfg(test)] mod test { use std::sync::Mutex; use insta::assert_snapshot; use swiftide_core::{MockSimplePrompt, indexing::Metadata, querying::TransformationEvent}; use super::*; assert_default_prompt_snapshot!("question" => "What is love?", "documents" => "My context"); #[tokio::test] async fn test_uses_current_if_present() { let mut mock_client = MockSimplePrompt::new(); // I'll buy a beer for the first person who can think of a less insane way to do this let received_prompt = Arc::new(Mutex::new(None)); let cloned = received_prompt.clone(); mock_client .expect_prompt() .withf(move |prompt| { cloned.lock().unwrap().replace(prompt.clone()); true }) .once() .returning(|_| Ok(String::default())); let documents = vec![ Document::new("First document", Some(Metadata::from(("some", "metadata")))), Document::new( "Second document", Some(Metadata::from(("other", "metadata"))), ), ]; let query: Query<states::Retrieved> = Query::builder() .original("original") .current("A fictional generated summary") .state(states::Retrieved) .transformation_history(vec![TransformationEvent::Retrieved { before: "abc".to_string(), after: "abc".to_string(), documents: documents.clone(), }]) .documents(documents) .build() .unwrap(); let transformer = Simple::builder().client(mock_client).build().unwrap(); transformer.answer(query).await.unwrap(); let received_prompt = received_prompt.lock().unwrap().take().unwrap(); let rendered = received_prompt.render().unwrap(); assert_snapshot!(rendered); } #[tokio::test] async fn test_custom_document_template() { let mut mock_client = MockSimplePrompt::new(); // I'll buy a beer for the first person who can 
think of a less insane way to do this let received_prompt = Arc::new(Mutex::new(None)); let cloned = received_prompt.clone(); mock_client .expect_prompt() .withf(move |prompt| { cloned.lock().unwrap().replace(prompt.clone()); true }) .once() .returning(|_| Ok(String::default())); let documents = vec![ Document::new("First document", Some(Metadata::from(("some", "metadata")))), Document::new( "Second document", Some(Metadata::from(("other", "metadata"))), ), ]; let query: Query<states::Retrieved> = Query::builder() .original("original") .current(String::default()) .state(states::Retrieved) .transformation_history(vec![TransformationEvent::Retrieved { before: "abc".to_string(), after: "abc".to_string(), documents: documents.clone(), }]) .documents(documents) .build() .unwrap(); let transformer = Simple::builder() .client(mock_client) .document_template(indoc::indoc! {" {% for key, value in metadata -%} {{ key }}: {{ value }} {% endfor -%} {{ content }}"}) .build() .unwrap(); transformer.answer(query).await.unwrap(); let received_prompt = received_prompt.lock().unwrap().take().unwrap(); let rendered = received_prompt.render().unwrap(); assert_snapshot!(rendered); } } ================================================ FILE: swiftide-query/src/answers/snapshots/swiftide_query__answers__simple__test__custom_document_template.snap ================================================ --- source: swiftide-query/src/answers/simple.rs expression: rendered --- Answer the following question based on the context provided: original ## Constraints * Do not include any information that is not in the provided context. * If the question cannot be answered by the provided context, state that it cannot be answered. * Answer the question completely and format it as markdown. 
## Context --- some: metadata First document --- other: metadata Second document --- ================================================ FILE: swiftide-query/src/answers/snapshots/swiftide_query__answers__simple__test__default_prompt.snap ================================================ --- source: swiftide-query/src/answers/simple.rs expression: prompt.render().await.unwrap() --- Answer the following question based on the context provided: What is love? ## Constraints * Do not include any information that is not in the provided context. * If the question cannot be answered by the provided context, state that it cannot be answered. * Answer the question completely and format it as markdown. ## Context --- My context --- ================================================ FILE: swiftide-query/src/answers/snapshots/swiftide_query__answers__simple__test__uses_current_if_present.snap ================================================ --- source: swiftide-query/src/answers/simple.rs expression: rendered --- Answer the following question based on the context provided: original ## Constraints * Do not include any information that is not in the provided context. * If the question cannot be answered by the provided context, state that it cannot be answered. * Answer the question completely and format it as markdown. ## Context --- A fictional generated summary --- ================================================ FILE: swiftide-query/src/evaluators/mod.rs ================================================ //! This module contains evaluators for evaluating the quality of a pipeline. pub mod ragas; ================================================ FILE: swiftide-query/src/evaluators/ragas.rs ================================================ //! The Ragas evaluator allows you to export a RAGAS compatible JSON dataset. //! //! RAGAS requires a ground truth to compare to. You can either record the answers for an initial //! dataset, or provide the ground truth yourself. //! //! 
Refer to the ragas documentation on how to use the dataset or take a look at a more involved //! example at [swiftide-tutorials](https://github.com/bosun-ai/swiftide-tutorial). //! //! # Example //! //! ```ignore //! # use swiftide_query::*; //! # use anyhow::{Result, Context}; //! # #[tokio::main] //! # async fn main() -> anyhow::Result<()> { //! //! let openai = swiftide::integrations::openai::OpenAi::default(); //! let qdrant = swiftide::integrations::qdrant::Qdrant::default(); //! //! let ragas = evaluators::ragas::Ragas::from_prepared_questions(questions); //! //! let pipeline = query::Pipeline::default() //! .evaluate_with(ragas.clone()) //! .then_transform_query(query_transformers::GenerateSubquestions::from_client(openai.clone())) //! .then_transform_query(query_transformers::Embed::from_client( //! openai.clone(), //! )) //! .then_retrieve(qdrant.clone()) //! .then_answer(answers::Simple::from_client(openai.clone())); //! //! pipeline.query_all(ragas.questions().await).await.unwrap(); //! //! std::fs::write("output.json", ragas.to_json().await).unwrap(); //! # Ok(()) //! # } //! ``` use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::json; use std::{collections::HashMap, str::FromStr, sync::Arc}; use tokio::sync::RwLock; use swiftide_core::{ EvaluateQuery, querying::{Query, QueryEvaluation, states}, }; /// Ragas evaluator to be used in a pipeline #[derive(Debug, Clone)] pub struct Ragas { dataset: Arc<RwLock<EvaluationDataSet>>, } /// Row structure for RAGAS compatible JSON #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct EvaluationData { question: String, answer: String, contexts: Vec<String>, ground_truth: String, } /// Dataset for RAGAS compatible JSON, indexed by question #[derive(Debug, Clone)] pub struct EvaluationDataSet(HashMap<String, EvaluationData>); impl Ragas { /// Builds a new Ragas evaluator from a list of questions or a list of tuples with questions and /// ground truths.
You can also call `parse` to load a dataset from a JSON string. pub fn from_prepared_questions(questions: impl Into<EvaluationDataSet>) -> Self { Ragas { dataset: Arc::new(RwLock::new(questions.into())), } } pub async fn questions(&self) -> Vec<Query<states::Pending>> { self.dataset.read().await.0.keys().map(Into::into).collect() } /// Records the current answers as ground truths in the dataset pub async fn record_answers_as_ground_truth(&self) { self.dataset.write().await.record_answers_as_ground_truth(); } /// Outputs the dataset as a JSON string compatible with RAGAS pub async fn to_json(&self) -> String { self.dataset.read().await.to_json() } } #[async_trait] impl EvaluateQuery for Ragas { #[tracing::instrument(skip_all)] async fn evaluate(&self, query: QueryEvaluation) -> Result<()> { let mut dataset = self.dataset.write().await; dataset.upsert_evaluation(&query) } } impl EvaluationDataSet { pub(crate) fn record_answers_as_ground_truth(&mut self) { for data in self.0.values_mut() { data.ground_truth.clone_from(&data.answer); } } pub(crate) fn upsert_evaluation(&mut self, query: &QueryEvaluation) -> Result<()> { match query { QueryEvaluation::RetrieveDocuments(query) => self.upsert_retrieved_documents(query), QueryEvaluation::AnswerQuery(query) => self.upsert_answer(query), } } // For each upsert, check if the question exists and update it, or return an error fn upsert_retrieved_documents(&mut self, query: &Query<states::Retrieved>) -> Result<()> { let question = query.original(); let data = self .0 .get_mut(question) .ok_or_else(|| anyhow::anyhow!("Question not found"))?; data.contexts = query .documents() .iter() .map(|d| d.content().to_string()) .collect::<Vec<_>>(); Ok(()) } fn upsert_answer(&mut self, query: &Query<states::Answered>) -> Result<()> { let question = query.original(); let data = self .0 .get_mut(question) .ok_or_else(|| anyhow::anyhow!("Question not found"))?; data.answer = query.answer().to_string(); Ok(()) } /// Outputs json for ragas /// /// #
Format /// /// ```json /// [ /// { /// "question": "What is the capital of France?", /// "answer": "Paris", /// "contexts": ["Paris is the capital of France"], /// "ground_truth": "Paris" /// }, /// { /// "question": "What is the capital of France?", /// "answer": "Paris", /// "contexts": ["Paris is the capital of France"], /// "ground_truth": "Paris" /// } /// ] /// ``` pub(crate) fn to_json(&self) -> String { let json_value = json!(self.0.values().collect::<Vec<_>>()); serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| json_value.to_string()) } } // Can just do a list of questions leaving ground truth, answers, contexts empty impl From<Vec<String>> for EvaluationDataSet { fn from(val: Vec<String>) -> Self { EvaluationDataSet( val.into_iter() .map(|question| { ( question.clone(), EvaluationData { question, ..EvaluationData::default() }, ) }) .collect(), ) } } impl From<&[String]> for EvaluationDataSet { fn from(val: &[String]) -> Self { EvaluationDataSet( val.iter() .map(|question| { ( question.clone(), EvaluationData { question: question.clone(), ..EvaluationData::default() }, ) }) .collect(), ) } } // Can take a list of tuples for questions and ground truths impl From<Vec<(String, String)>> for EvaluationDataSet { fn from(val: Vec<(String, String)>) -> Self { EvaluationDataSet( val.into_iter() .map(|(question, ground_truth)| { ( question.clone(), EvaluationData { question, ground_truth, ..EvaluationData::default() }, ) }) .collect(), ) } } /// Parse an existing dataset from a JSON string impl FromStr for EvaluationDataSet { type Err = serde_json::Error; fn from_str(val: &str) -> std::prelude::v1::Result<Self, Self::Err> { let data: Vec<EvaluationData> = serde_json::from_str(val)?; Ok(EvaluationDataSet( data.into_iter() .map(|data| (data.question.clone(), data)) .collect(), )) } } #[cfg(test)] mod tests { use super::*; use std::sync::Arc; use swiftide_core::querying::{Query, QueryEvaluation}; use tokio::sync::RwLock; #[tokio::test] async fn 
test_ragas_from_prepared_questions() { let questions = vec!["What is Rust?".to_string(), "What is Tokio?".to_string()]; let ragas = Ragas::from_prepared_questions(questions.clone()); let stored_questions = ragas.questions().await; assert_eq!(stored_questions.len(), questions.len()); for question in questions { assert!(stored_questions.iter().any(|q| q.original() == question)); } } #[tokio::test] async fn test_ragas_record_answers_as_ground_truth() { let dataset = Arc::new(RwLock::new(EvaluationDataSet::from(vec![( "What is Rust?".to_string(), "A programming language".to_string(), )]))); let ragas = Ragas { dataset: dataset.clone(), }; { let mut lock = dataset.write().await; let data = lock.0.get_mut("What is Rust?").unwrap(); data.answer = "A systems programming language".to_string(); } ragas.record_answers_as_ground_truth().await; let updated_data = ragas.dataset.read().await; let data = updated_data.0.get("What is Rust?").unwrap(); assert_eq!(data.ground_truth, "A systems programming language"); } #[tokio::test] async fn test_ragas_to_json() { let dataset = EvaluationDataSet::from(vec![( "What is Rust?".to_string(), "A programming language".to_string(), )]); let ragas = Ragas { dataset: Arc::new(RwLock::new(dataset)), }; let json_output = ragas.to_json().await; let expected_json = "[\n {\n \"answer\": \"\",\n \"contexts\": [],\n \"ground_truth\": \"A programming language\",\n \"question\": \"What is Rust?\"\n }\n]"; assert_eq!(json_output, expected_json); } #[tokio::test] async fn test_evaluate_query_upsert_retrieved_documents() { let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]); let ragas = Ragas { dataset: Arc::new(RwLock::new(dataset.clone())), }; let query = Query::builder() .original("What is Rust?") .documents(vec!["Rust is a language".into()]) .build() .unwrap(); let evaluation = QueryEvaluation::RetrieveDocuments(query.clone()); ragas.evaluate(evaluation).await.unwrap(); let updated_data = ragas.dataset.read().await; let data = 
updated_data.0.get("What is Rust?").unwrap(); assert_eq!(data.contexts, vec!["Rust is a language"]); } #[tokio::test] async fn test_evaluate_query_upsert_answer() { let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]); let ragas = Ragas { dataset: Arc::new(RwLock::new(dataset.clone())), }; let query = Query::builder() .original("What is Rust?") .current("A systems programming language") .build() .unwrap(); let evaluation = QueryEvaluation::AnswerQuery(query.clone()); ragas.evaluate(evaluation).await.unwrap(); let updated_data = ragas.dataset.read().await; let data = updated_data.0.get("What is Rust?").unwrap(); assert_eq!(data.answer, "A systems programming language"); } #[tokio::test] async fn test_evaluation_dataset_record_answers_as_ground_truth() { let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]); let data = dataset.0.get_mut("What is Rust?").unwrap(); data.answer = "A programming language".to_string(); dataset.record_answers_as_ground_truth(); let data = dataset.0.get("What is Rust?").unwrap(); assert_eq!(data.ground_truth, "A programming language"); } #[tokio::test] async fn test_evaluation_dataset_to_json() { let dataset = EvaluationDataSet::from(vec![( "What is Rust?".to_string(), "A programming language".to_string(), )]); let json_output = dataset.to_json(); let expected_json = "[\n {\n \"answer\": \"\",\n \"contexts\": [],\n \"ground_truth\": \"A programming language\",\n \"question\": \"What is Rust?\"\n }\n]"; assert_eq!(json_output, expected_json); } #[tokio::test] async fn test_evaluation_dataset_upsert_retrieved_documents() { let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]); let query = Query::builder() .original("What is Rust?") .documents(vec!["Rust is a language".into()]) .build() .unwrap(); dataset .upsert_evaluation(&QueryEvaluation::RetrieveDocuments(query.clone())) .unwrap(); let data = dataset.0.get("What is Rust?").unwrap(); assert_eq!(data.contexts, vec!["Rust is a 
language"]); } #[tokio::test] async fn test_evaluation_dataset_upsert_answer() { let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]); let query = Query::builder() .original("What is Rust?") .current("A systems programming language") .build() .unwrap(); dataset .upsert_evaluation(&QueryEvaluation::AnswerQuery(query.clone())) .unwrap(); let data = dataset.0.get("What is Rust?").unwrap(); assert_eq!(data.answer, "A systems programming language"); } } ================================================ FILE: swiftide-query/src/lib.rs ================================================ // show feature flags in the generated documentation // https://doc.rust-lang.org/rustdoc/unstable-features.html#extensions-to-the-doc-attribute #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, doc(auto_cfg))] #![doc(html_logo_url = "https://github.com/bosun-ai/swiftide/raw/master/images/logo.png")] pub mod answers; mod query; pub mod query_transformers; pub mod response_transformers; pub use query::*; pub mod evaluators; ================================================ FILE: swiftide-query/src/query/mod.rs ================================================ mod pipeline; pub use pipeline::Pipeline; ================================================ FILE: swiftide-query/src/query/pipeline.rs ================================================ //! A query pipeline can be used to answer a user query //! //! The pipeline has a sequence of steps: //! 1. Transform the query (i.e. Generating subquestions, embeddings) //! 2. Retrieve documents from storage //! 3. Transform these documents into a suitable context for answering //! 4. Answering the query //! //! WARN: The query pipeline is in a very early stage! //! //! Under the hood, it uses a [`SearchStrategy`] that an implementor of [`Retrieve`] (i.e. Qdrant) //! must implement. //! //! A query pipeline is lazy and only runs when query is called. 
use futures_util::TryFutureExt as _; use std::sync::Arc; use swiftide_core::{ EvaluateQuery, prelude::*, querying::{ Answer, Query, QueryState, QueryStream, Retrieve, SearchStrategy, TransformQuery, TransformResponse, search_strategies::SimilaritySingleEmbedding, states, }, }; use tokio::sync::mpsc::Sender; /// The starting point of a query pipeline pub struct Pipeline< 'stream, STRATEGY: SearchStrategy = SimilaritySingleEmbedding, STATE: QueryState = states::Pending, > { search_strategy: STRATEGY, stream: QueryStream<'stream, STATE>, query_sender: Sender<Result<Query<states::Pending>>>, evaluator: Option<Arc<Box<dyn EvaluateQuery>>>, default_concurrency: usize, } /// By default the [`SearchStrategy`] is [`SimilaritySingleEmbedding`], which embed the current /// query and returns a collection of documents. impl Default for Pipeline<'_, SimilaritySingleEmbedding> { fn default() -> Self { let stream = QueryStream::default(); Self { search_strategy: SimilaritySingleEmbedding::default(), query_sender: stream .sender .clone() .expect("Pipeline received stream without query entrypoint"), stream, evaluator: None, default_concurrency: num_cpus::get(), } } } impl<'a, STRATEGY: SearchStrategy> Pipeline<'a, STRATEGY> { /// Create a query pipeline from a [`SearchStrategy`] /// /// # Panics /// /// Panics if the inner stream fails to build #[must_use] pub fn from_search_strategy(strategy: STRATEGY) -> Pipeline<'a, STRATEGY> { let stream = QueryStream::default(); Pipeline { search_strategy: strategy, query_sender: stream .sender .clone() .expect("Pipeline received stream without query entrypoint"), stream, evaluator: None, default_concurrency: num_cpus::get(), } } } impl<'stream: 'static, STRATEGY> Pipeline<'stream, STRATEGY, states::Pending> where STRATEGY: SearchStrategy, { /// Evaluate queries with an evaluator #[must_use] pub fn evaluate_with<T: EvaluateQuery + 'stream>(mut self, evaluator: T) -> Self { self.evaluator = Some(Arc::new(Box::new(evaluator))); self } /// 
Transform a query into something else, see [`crate::query_transformers`] #[must_use] pub fn then_transform_query<T: TransformQuery + 'stream>( self, transformer: T, ) -> Pipeline<'stream, STRATEGY, states::Pending> { let transformer = Arc::new(transformer); let Pipeline { stream, query_sender, search_strategy, evaluator, default_concurrency, } = self; let new_stream = stream .map_ok(move |query| { let transformer = Arc::clone(&transformer); let span = tracing::info_span!("then_transform_query", query = ?query); tokio::spawn( async move { let transformed_query = transformer.transform_query(query).await?; tracing::debug!( transformed_query = transformed_query.current(), query_transformer = transformer.name(), "Transformed query" ); Ok(transformed_query) } .instrument(span.or_current()), ) .err_into::<anyhow::Error>() }) .try_buffer_unordered(default_concurrency) .map(|x| x.and_then(|x| x)); Pipeline { stream: new_stream.boxed().into(), search_strategy, query_sender, evaluator, default_concurrency, } } } impl<'stream: 'static, STRATEGY: SearchStrategy + 'stream> Pipeline<'stream, STRATEGY, states::Pending> { /// Executes the query based on a search query with a retriever #[must_use] pub fn then_retrieve<T: ToOwned<Owned = impl Retrieve<STRATEGY> + 'stream>>( self, retriever: T, ) -> Pipeline<'stream, STRATEGY, states::Retrieved> { let retriever = Arc::new(retriever.to_owned()); let Pipeline { stream, query_sender, search_strategy, evaluator, default_concurrency, } = self; let strategy_for_stream = search_strategy.clone(); let evaluator_for_stream = evaluator.clone(); let new_stream = stream .map_ok(move |query| { let search_strategy = strategy_for_stream.clone(); let retriever = Arc::clone(&retriever); let span = tracing::info_span!("then_retrieve", query = ?query); let evaluator_for_stream = evaluator_for_stream.clone(); tokio::spawn( async move { let result = retriever.retrieve(&search_strategy, query).await?; tracing::debug!( num_documents = 
result.documents().len(), total_bytes = result .documents() .iter() .map(|d| d.bytes().len()) .sum::<usize>(), "Retrieved documents" ); if let Some(evaluator) = evaluator_for_stream.as_ref() { evaluator.evaluate(result.clone().into()).await?; Ok(result) } else { Ok(result) } } .instrument(span.or_current()), ) .err_into::<anyhow::Error>() }) .try_buffer_unordered(default_concurrency) .map(|x| x.and_then(|x| x)); Pipeline { stream: new_stream.boxed().into(), search_strategy: search_strategy.clone(), query_sender, evaluator, default_concurrency, } } } impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> { /// Transforms a retrieved query into something else #[must_use] pub fn then_transform_response<T: TransformResponse + 'stream>( self, transformer: T, ) -> Pipeline<'stream, STRATEGY, states::Retrieved> { let transformer = Arc::new(transformer); let Pipeline { stream, query_sender, search_strategy, evaluator, default_concurrency, } = self; let new_stream = stream .map_ok(move |query| { let transformer = Arc::clone(&transformer); let span = tracing::info_span!("then_transform_response", query = ?query); tokio::spawn( async move { let transformed_query = transformer.transform_response(query).await?; tracing::debug!( transformed_query = transformed_query.current(), response_transformer = transformer.name(), "Transformed response" ); Ok(transformed_query) } .instrument(span.or_current()), ) .err_into::<anyhow::Error>() }) .try_buffer_unordered(default_concurrency) .map(|x| x.and_then(|x| x)); Pipeline { stream: new_stream.boxed().into(), search_strategy, query_sender, evaluator, default_concurrency, } } } impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> { /// Generates an answer based on previous transformations #[must_use] pub fn then_answer<T: Answer + 'stream>( self, answerer: T, ) -> Pipeline<'stream, STRATEGY, states::Answered> { let answerer = Arc::new(answerer); let Pipeline 
{ stream, query_sender, search_strategy, evaluator, default_concurrency, } = self; let evaluator_for_stream = evaluator.clone(); let new_stream = stream .map_ok(move |query: Query<states::Retrieved>| { let answerer = Arc::clone(&answerer); let span = tracing::info_span!("then_answer", query = ?query); let evaluator_for_stream = evaluator_for_stream.clone(); tokio::spawn( async move { tracing::debug!(answerer = answerer.name(), "Answering query"); let result = answerer.answer(query).await?; if let Some(evaluator) = evaluator_for_stream.as_ref() { evaluator.evaluate(result.clone().into()).await?; Ok(result) } else { Ok(result) } } .instrument(span.or_current()), ) .err_into::<anyhow::Error>() }) .try_buffer_unordered(default_concurrency) .map(|x| x.and_then(|x| x)); Pipeline { stream: new_stream.boxed().into(), search_strategy, query_sender, evaluator, default_concurrency, } } } impl<STRATEGY: SearchStrategy> Pipeline<'_, STRATEGY, states::Answered> { /// Runs the pipeline with a user query, accepts `&str` as well. /// /// # Errors /// /// Errors if any of the transformations failed or no response was found #[tracing::instrument(skip_all, name = "query_pipeline.query")] pub async fn query( mut self, query: impl Into<Query<states::Pending>>, ) -> Result<Query<states::Answered>> { tracing::debug!("Sending query"); let now = std::time::Instant::now(); self.query_sender.send(Ok(query.into())).await?; let answer = self.stream.try_next().await?.ok_or_else(|| { anyhow::anyhow!("Pipeline did not receive a response from the query stream") }); let elapsed_in_seconds = now.elapsed().as_secs(); tracing::warn!( elapsed_in_seconds, "Answered query in {} seconds", elapsed_in_seconds ); answer } /// Runs the pipeline with a user query, accepts `&str` as well. /// /// Does not consume the pipeline and requires a mutable reference. This allows /// the pipeline to be reused. 
/// /// # Errors /// /// Errors if any of the transformations failed or no response was found #[tracing::instrument(skip_all, name = "query_pipeline.query_mut")] pub async fn query_mut( &mut self, query: impl Into<Query<states::Pending>>, ) -> Result<Query<states::Answered>> { tracing::warn!("Sending query"); let now = std::time::Instant::now(); self.query_sender.send(Ok(query.into())).await?; let answer = self .stream .by_ref() .take(1) .try_next() .await? .ok_or_else(|| { anyhow::anyhow!("Pipeline did not receive a response from the query stream") }); tracing::debug!(?answer, "Received an answer"); let elapsed_in_seconds = now.elapsed().as_secs(); tracing::warn!( elapsed_in_seconds, "Answered query in {} seconds", elapsed_in_seconds ); answer } /// Runs the pipeline with multiple queries /// /// # Errors /// /// Errors if any of the transformations failed, no response was found, or the stream was /// closed. #[tracing::instrument(skip_all, name = "query_pipeline.query_all")] pub async fn query_all( self, queries: Vec<impl Into<Query<states::Pending>> + Clone>, ) -> Result<Vec<Query<states::Answered>>> { tracing::warn!("Sending queries"); let now = std::time::Instant::now(); let Pipeline { query_sender, mut stream, .. } = self; for query in &queries { query_sender.send(Ok(query.clone().into())).await?; } tracing::info!("All queries sent"); let mut results = vec![]; while let Some(result) = stream.try_next().await? 
{ tracing::debug!(?result, "Received an answer"); results.push(result); if results.len() == queries.len() { break; } } let elapsed_in_seconds = now.elapsed().as_secs(); tracing::warn!( num_queries = queries.len(), elapsed_in_seconds, "Answered all queries in {} seconds", elapsed_in_seconds ); Ok(results) } } #[cfg(test)] mod test { use swiftide_core::{ MockAnswer, MockTransformQuery, MockTransformResponse, querying::search_strategies, }; use super::*; #[tokio::test] async fn test_closures_in_each_step() { let pipeline = Pipeline::default() .then_transform_query(move |query: Query<states::Pending>| Ok(query)) .then_retrieve( move |_: &search_strategies::SimilaritySingleEmbedding, query: Query<states::Pending>| { Ok(query.retrieved_documents(vec![])) }, ) .then_transform_response(Ok) .then_answer(move |query: Query<states::Retrieved>| Ok(query.answered("Ok"))); let response = pipeline.query("What").await.unwrap(); assert_eq!(response.answer(), "Ok"); } #[tokio::test] async fn test_all_steps_should_accept_dyn_box() { let mut query_transformer = MockTransformQuery::new(); query_transformer.expect_transform_query().returning(Ok); let mut response_transformer = MockTransformResponse::new(); response_transformer .expect_transform_response() .returning(Ok); let mut answer_transformer = MockAnswer::new(); answer_transformer .expect_answer() .returning(|query| Ok(query.answered("OK"))); let pipeline = Pipeline::default() .then_transform_query(Box::new(query_transformer) as Box<dyn TransformQuery>) .then_retrieve( |_: &search_strategies::SimilaritySingleEmbedding, query: Query<states::Pending>| { Ok(query.retrieved_documents(vec![])) }, ) .then_transform_response(Box::new(response_transformer) as Box<dyn TransformResponse>) .then_answer(Box::new(answer_transformer) as Box<dyn Answer>); let response = pipeline.query("What").await.unwrap(); assert_eq!(response.answer(), "OK"); } #[tokio::test] async fn test_reuse_with_query_mut() { let mut pipeline = Pipeline::default() 
.then_transform_query(move |query: Query<states::Pending>| Ok(query)) .then_retrieve( move |_: &search_strategies::SimilaritySingleEmbedding, query: Query<states::Pending>| { Ok(query.retrieved_documents(vec![])) }, ) .then_transform_response(Ok) .then_answer(move |query: Query<states::Retrieved>| Ok(query.answered("Ok"))); let response = pipeline.query_mut("What").await.unwrap(); assert_eq!(response.answer(), "Ok"); let response = pipeline.query_mut("What").await.unwrap(); assert_eq!(response.answer(), "Ok"); } } ================================================ FILE: swiftide-query/src/query_transformers/embed.rs ================================================ use std::sync::Arc; use swiftide_core::{ indexing::EmbeddingModel, prelude::*, querying::{Query, TransformQuery, states}, }; #[derive(Debug, Clone)] pub struct Embed { embed_model: Arc<dyn EmbeddingModel>, } impl Embed { pub fn from_client(client: impl EmbeddingModel + 'static) -> Embed { Embed { embed_model: Arc::new(client), } } } #[async_trait] impl TransformQuery for Embed { #[tracing::instrument(skip_all)] async fn transform_query( &self, mut query: Query<states::Pending>, ) -> Result<Query<states::Pending>> { let Some(embedding) = self .embed_model .embed(vec![query.current().to_string()]) .await? .pop() else { anyhow::bail!("Failed to embed query") }; query.embedding = Some(embedding); Ok(query) } } ================================================ FILE: swiftide-query/src/query_transformers/generate_subquestions.rs ================================================ //! Generate subquestions for a query //! //! 
Useful for similarity search where you want a wider vector coverage use std::sync::Arc; use swiftide_core::{ indexing::SimplePrompt, prelude::*, prompt::Prompt, querying::{Query, TransformQuery, states}, }; #[derive(Debug, Clone, Builder)] pub struct GenerateSubquestions { #[builder(setter(custom))] client: Arc<dyn SimplePrompt>, #[builder(default = "default_prompt()")] prompt_template: Prompt, #[builder(default = "5")] num_questions: usize, } impl GenerateSubquestions { pub fn builder() -> GenerateSubquestionsBuilder { GenerateSubquestionsBuilder::default() } /// Builds a new subquestions generator from a client that implements [`SimplePrompt`] /// /// # Panics /// /// Panics if the build failed pub fn from_client(client: impl SimplePrompt + 'static) -> GenerateSubquestions { GenerateSubquestionsBuilder::default() .client(client) .to_owned() .build() .expect("Failed to build GenerateSubquestions") } } impl GenerateSubquestionsBuilder { pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self { self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>); self } } fn default_prompt() -> Prompt { indoc::indoc!(" Your job is to help a query tool find the right context. Given the following question: {{question}} Please think of {{num_questions}} additional questions that can help answering the original question. Especially consider what might be relevant to answer the question, like dependencies, usage and structure of the code. Please respond with the original question and the additional questions only. 
## Example - {{question}} - Additional question 1 - Additional question 2 - Additional question 3 - Additional question 4 - Additional question 5 ").into() } #[async_trait] impl TransformQuery for GenerateSubquestions { #[tracing::instrument(skip_self)] async fn transform_query( &self, mut query: Query<states::Pending>, ) -> Result<Query<states::Pending>> { let new_query = self .client .prompt( self.prompt_template .clone() .with_context_value("question", query.current()) .with_context_value("num_questions", self.num_questions), ) .await?; query.transformed_query(new_query); Ok(query) } } #[cfg(test)] mod test { use super::*; assert_default_prompt_snapshot!("question" => "What is love?", "num_questions" => 5); } ================================================ FILE: swiftide-query/src/query_transformers/mod.rs ================================================ //! Transform queries that are yet to be made mod generate_subquestions; pub use generate_subquestions::GenerateSubquestions; mod embed; mod sparse_embed; pub use embed::Embed; pub use sparse_embed::SparseEmbed; ================================================ FILE: swiftide-query/src/query_transformers/snapshots/swiftide_query__query_transformers__generate_subquestions__test__default_prompt.snap ================================================ --- source: swiftide-query/src/query_transformers/generate_subquestions.rs expression: prompt.render().await.unwrap() --- Your job is to help a query tool find the right context. Given the following question: What is love? Please think of 5 additional questions that can help answering the original question. Especially consider what might be relevant to answer the question, like dependencies, usage and structure of the code. Please respond with the original question and the additional questions only. ## Example - What is love? 
- Additional question 1
- Additional question 2
- Additional question 3
- Additional question 4
- Additional question 5


================================================
FILE: swiftide-query/src/query_transformers/sparse_embed.rs
================================================
use std::sync::Arc;

use swiftide_core::{
    SparseEmbeddingModel,
    prelude::*,
    querying::{Query, TransformQuery, states},
};

/// Embed a query with a sparse embedding.
#[derive(Debug, Clone)]
pub struct SparseEmbed {
    embed_model: Arc<dyn SparseEmbeddingModel>,
}

impl SparseEmbed {
    /// Wraps a [`SparseEmbeddingModel`] client in a query transformer.
    pub fn from_client(client: impl SparseEmbeddingModel + 'static) -> SparseEmbed {
        SparseEmbed {
            embed_model: Arc::new(client),
        }
    }
}

#[async_trait]
impl TransformQuery for SparseEmbed {
    #[tracing::instrument(skip_all)]
    async fn transform_query(
        &self,
        mut query: Query<states::Pending>,
    ) -> Result<Query<states::Pending>> {
        // Embed the current query text; the model returns one embedding per input,
        // so popping the single result is equivalent to taking the first.
        let mut embeddings = self
            .embed_model
            .sparse_embed(vec![query.current().to_string()])
            .await?;

        let Some(embedding) = embeddings.pop() else {
            anyhow::bail!("Failed to embed query")
        };

        query.sparse_embedding = Some(embedding);
        Ok(query)
    }
}


================================================
FILE: swiftide-query/src/response_transformers/mod.rs
================================================
//! Transform retrieved queries
mod summary;
pub use summary::*;


================================================
FILE: swiftide-query/src/response_transformers/snapshots/swiftide_query__response_transformers__summary__test__default_prompt.snap
================================================
---
source: swiftide-query/src/response_transformers/summary.rs
expression: prompt.render().await.unwrap()
---
Your job is to help a query tool find the right context.
Summarize the following documents.

## Constraints
* Do not add any information that is not available in the documents.
* Summarize comprehensively and ensure no data that might be important is left out.
* Summarize as a single markdown document ## Documents --- First document --- --- Second Document --- ================================================ FILE: swiftide-query/src/response_transformers/summary.rs ================================================ use std::sync::Arc; use swiftide_core::{ TransformResponse, indexing::SimplePrompt, prelude::*, prompt::Prompt, querying::{Query, states}, }; #[derive(Debug, Clone, Builder)] pub struct Summary { #[builder(setter(custom))] client: Arc<dyn SimplePrompt>, #[builder(default = "default_prompt()")] prompt_template: Prompt, } impl Summary { pub fn builder() -> SummaryBuilder { SummaryBuilder::default() } /// Builds a new summary generator from a client that implements [`SimplePrompt`]. /// /// Will try to summarize documents using an llm, instructed to preserve as much information as /// possible. /// /// # Panics /// /// Panics if the build failed pub fn from_client(client: impl SimplePrompt + 'static) -> Summary { SummaryBuilder::default() .client(client) .to_owned() .build() .expect("Failed to build Summary") } } impl SummaryBuilder { pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self { self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>); self } } fn default_prompt() -> Prompt { indoc::indoc!( " Your job is to help a query tool find the right context. Summarize the following documents. ## Constraints * Do not add any information that is not available in the documents. * Summarize comprehensively and ensure no data that might be important is left out. 
* Summarize as a single markdown document ## Documents {% for document in documents -%} --- {{ document.content }} --- {% endfor -%} " ) .into() } #[async_trait] impl TransformResponse for Summary { #[tracing::instrument(skip_all)] async fn transform_response( &self, mut query: Query<states::Retrieved>, ) -> Result<Query<states::Retrieved>> { let new_response = self .client .prompt( self.prompt_template .clone() .with_context_value("documents", query.documents()), ) .await?; query.transformed_response(new_response); Ok(query) } } #[cfg(test)] mod test { use swiftide_core::document::Document; use super::*; assert_default_prompt_snapshot!("documents" => vec![Document::from("First document"), Document::from("Second Document")]); } ================================================ FILE: swiftide-test-utils/Cargo.toml ================================================ cargo-features = ["edition2024"] [package] name = "swiftide-test-utils" publish = false version.workspace = true edition.workspace = true license.workspace = true readme.workspace = true keywords.workspace = true description.workspace = true categories.workspace = true repository.workspace = true homepage.workspace = true [dependencies] swiftide-integrations = { path = "../swiftide-integrations", features = [ "openai", ] } serde = { workspace = true } serde_json = { workspace = true } async-openai = { workspace = true } testcontainers = { workspace = true } wiremock = { workspace = true } [features] default = ["test-utils"] test-utils = [] [package.metadata.docs.rs] all-features = true cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: swiftide-test-utils/src/lib.rs ================================================ // show feature flags in the generated documentation // https://doc.rust-lang.org/rustdoc/unstable-features.html#extensions-to-the-doc-attribute #![cfg_attr(docsrs, feature(doc_cfg))] 
#![cfg_attr(docsrs, doc(auto_cfg))]
#![doc(html_logo_url = "https://github.com/bosun-ai/swiftide/raw/master/images/logo.png")]

#[cfg(feature = "test-utils")]
mod test_utils;
#[cfg(feature = "test-utils")]
pub use test_utils::*;


================================================
FILE: swiftide-test-utils/src/test_utils.rs
================================================
#![allow(missing_docs)]
#![allow(clippy::missing_panics_doc)]

use serde_json::json;
use testcontainers::{
    ContainerAsync, GenericImage, ImageExt,
    core::{IntoContainerPort, WaitFor, wait::HttpWaitStrategy},
    runners::AsyncRunner,
};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

use swiftide_integrations as integrations;

/// Builds an OpenAI client that talks to `mock_server_uri` instead of the real
/// API, preconfigured with the given embed and prompt model names.
pub fn openai_client(
    mock_server_uri: &str,
    embed_model: &str,
    prompt_model: &str,
) -> integrations::openai::OpenAI {
    // Point async-openai at the mock server rather than api.openai.com
    let openai_config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server_uri);
    let inner_client = async_openai::Client::with_config(openai_config);

    integrations::openai::OpenAI::builder()
        .client(inner_client)
        .default_options(
            integrations::openai::Options::builder()
                .embed_model(embed_model)
                .prompt_model(prompt_model)
                .build()
                .unwrap(),
        )
        .build()
        .expect("Can create OpenAI client.")
}

/// Setup Qdrant container.
/// Returns container server and `server_url`.
pub async fn start_qdrant() -> (ContainerAsync<GenericImage>, String) {
    let container = testcontainers::GenericImage::new("qdrant/qdrant", "v1.13.4")
        .with_exposed_port(6334.into())
        .with_exposed_port(6333.into())
        .with_wait_for(testcontainers::core::WaitFor::http(
            HttpWaitStrategy::new("/readyz")
                .with_port(6333.into())
                .with_expected_status_code(200_u16),
        ))
        .start()
        .await
        .expect("Qdrant started");

    // 6333 is the HTTP/readiness port; 6334 (gRPC) is what clients connect to
    let port = container.get_host_port_ipv4(6334).await.unwrap();
    let url = format!("http://127.0.0.1:{port}");

    (container, url)
}

/// Setup Redis container for caching in the test.
/// Returns container server and `server_url`.
pub async fn start_redis() -> (ContainerAsync<GenericImage>, String) {
    let container = testcontainers::GenericImage::new("redis", "7-alpine")
        .with_exposed_port(6379.into())
        .with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
            "Ready to accept connections",
        ))
        .start()
        .await
        .expect("Redis started");

    // The container maps 6379 to a random host port; resolve both host and port
    let host = container.get_host().await.unwrap();
    let port = container.get_host_port_ipv4(6379).await.unwrap();
    let url = format!("redis://{host}:{port}");

    (container, url)
}

/// Setup Postgres container.
/// Returns container server and `server_url`.
pub async fn start_postgres() -> (ContainerAsync<GenericImage>, String) {
    let container = testcontainers::GenericImage::new("pgvector/pgvector", "pg17")
        .with_wait_for(WaitFor::message_on_stdout(
            "database system is ready to accept connections",
        ))
        .with_exposed_port(5432.tcp())
        .with_env_var("POSTGRES_USER", "myuser")
        .with_env_var("POSTGRES_PASSWORD", "mypassword")
        .with_env_var("POSTGRES_DB", "mydatabase")
        .start()
        .await
        .expect("Failed to start Postgres container");

    // Construct the connection URL using the dynamically assigned port
    let host_port = container.get_host_port_ipv4(5432).await.unwrap();
    let postgres_url = format!("postgresql://myuser:mypassword@127.0.0.1:{host_port}/mydatabase");

    (container, postgres_url)
}

/// Mock embeddings creation endpoint.
/// `embeddings_count` controls number of returned embedding vectors.
pub async fn mock_embeddings(mock_server: &MockServer, embeddings_count: u8) {
    // One zero-vector per requested embedding, indexed the way the OpenAI API
    // indexes its results.
    let embeddings = (0..embeddings_count)
        .map(|i| {
            json!(
            {
                "object": "embedding",
                "embedding": vec![0; 1536],
                "index": i
            })
        })
        .collect::<Vec<serde_json::Value>>();
    let data = serde_json::Value::Array(embeddings);

    Mock::given(method("POST"))
        .and(path("/embeddings"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({
            "object": "list",
            "data": data,
            "model": "text-embedding-ada-002",
            "usage": {
                "prompt_tokens": 8,
                "total_tokens": 8
            }
        })))
        .mount(mock_server)
        .await;
}

/// Mock the chat completions endpoint with a single canned assistant reply.
pub async fn mock_chat_completions(mock_server: &MockServer) {
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1_677_652_288,
            "model": "gpt-3.5-turbo-0125",
            "system_fingerprint": "fp_44709d6fcb",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "\n\nHello there, how may I assist you today?",
                },
                "logprobs": null,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 9,
                "completion_tokens": 12,
                "total_tokens": 21
            }
        })))
        .mount(mock_server)
        .await;
}


================================================
FILE: typos.toml
================================================
[files]
# Autogenerated
extend-exclude = ["CHANGELOG.md", "cliff.toml"]