Repository: cloudwego/eino Branch: main Commit: e2eea8eaf360 Files: 312 Total size: 2.8 MB Directory structure: gitextract_0ddnn1r5/ ├── .github/ │ ├── .codedev.yml │ ├── .commit-rules.json │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ ├── PULL_REQUEST_TEMPLATE.md │ └── workflows/ │ ├── pr-check.yml │ ├── tag-notification.yml │ └── tests.yml ├── .gitignore ├── .golangci.yaml ├── .licenserc.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE-APACHE ├── README.md ├── README.zh_CN.md ├── _typos.toml ├── adk/ │ ├── agent_tool.go │ ├── agent_tool_test.go │ ├── call_option.go │ ├── call_option_test.go │ ├── callback.go │ ├── callback_integration_test.go │ ├── callback_test.go │ ├── chatmodel.go │ ├── chatmodel_retry_test.go │ ├── chatmodel_test.go │ ├── config.go │ ├── deterministic_transfer.go │ ├── deterministic_transfer_test.go │ ├── filesystem/ │ │ ├── backend.go │ │ ├── backend_inmemory.go │ │ └── backend_inmemory_test.go │ ├── flow.go │ ├── flow_test.go │ ├── handler.go │ ├── handler_test.go │ ├── instruction.go │ ├── interface.go │ ├── internal/ │ │ └── config.go │ ├── interrupt.go │ ├── interrupt_test.go │ ├── middlewares/ │ │ ├── dynamictool/ │ │ │ └── toolsearch/ │ │ │ ├── toolsearch.go │ │ │ └── toolsearch_test.go │ │ ├── filesystem/ │ │ │ ├── backend.go │ │ │ ├── filesystem.go │ │ │ ├── filesystem_test.go │ │ │ ├── large_tool_result.go │ │ │ ├── large_tool_result_test.go │ │ │ └── prompt.go │ │ ├── patchtoolcalls/ │ │ │ ├── patchtoolcalls.go │ │ │ └── patchtoolcalls_test.go │ │ ├── plantask/ │ │ │ ├── backend_test.go │ │ │ ├── plantask.go │ │ │ ├── plantask_test.go │ │ │ ├── task.go │ │ │ ├── task_create.go │ │ │ ├── task_create_test.go │ │ │ ├── task_get.go │ │ │ ├── task_get_test.go │ │ │ ├── task_list.go │ │ │ ├── task_list_test.go │ │ │ ├── task_update.go │ │ │ └── task_update_test.go │ │ ├── reduction/ │ │ │ ├── consts.go │ │ │ ├── internal/ │ │ │ │ ├── clear_tool_result.go │ │ │ │ ├── clear_tool_result_test.go │ │ │ │ 
├── large_tool_result.go │ │ │ │ ├── large_tool_result_test.go │ │ │ │ └── tool_result.go │ │ │ ├── legacy.go │ │ │ ├── reduction.go │ │ │ └── reduction_test.go │ │ ├── skill/ │ │ │ ├── filesystem_backend.go │ │ │ ├── filesystem_backend_test.go │ │ │ ├── prompt.go │ │ │ ├── skill.go │ │ │ └── skill_test.go │ │ └── summarization/ │ │ ├── consts.go │ │ ├── customized_action.go │ │ ├── prompt.go │ │ ├── summarization.go │ │ └── summarization_test.go │ ├── prebuilt/ │ │ ├── deep/ │ │ │ ├── checkpoint_compat_resume_test.go │ │ │ ├── deep.go │ │ │ ├── deep_test.go │ │ │ ├── prompt.go │ │ │ ├── task_tool.go │ │ │ ├── task_tool_test.go │ │ │ ├── testdata/ │ │ │ │ └── _gen/ │ │ │ │ └── generate_test.go │ │ │ └── types.go │ │ ├── integration_test.go │ │ ├── planexecute/ │ │ │ ├── plan_execute.go │ │ │ ├── plan_execute_test.go │ │ │ └── utils.go │ │ └── supervisor/ │ │ ├── supervisor.go │ │ └── supervisor_test.go │ ├── react.go │ ├── react_test.go │ ├── retry_chatmodel.go │ ├── runctx.go │ ├── runctx_test.go │ ├── runner.go │ ├── runner_test.go │ ├── utils.go │ ├── utils_test.go │ ├── workflow.go │ ├── workflow_test.go │ ├── wrappers.go │ └── wrappers_test.go ├── callbacks/ │ ├── aspect_inject.go │ ├── aspect_inject_test.go │ ├── doc.go │ ├── handler_builder.go │ ├── interface.go │ └── interface_test.go ├── components/ │ ├── document/ │ │ ├── callback_extra_loader.go │ │ ├── callback_extra_transformer.go │ │ ├── doc.go │ │ ├── interface.go │ │ ├── option.go │ │ ├── option_test.go │ │ └── parser/ │ │ ├── doc.go │ │ ├── ext_parser.go │ │ ├── interface.go │ │ ├── option.go │ │ ├── option_test.go │ │ ├── parser_test.go │ │ ├── testdata/ │ │ │ └── test.md │ │ └── text_parser.go │ ├── embedding/ │ │ ├── callback_extra.go │ │ ├── callback_extra_test.go │ │ ├── doc.go │ │ ├── interface.go │ │ ├── option.go │ │ └── option_test.go │ ├── indexer/ │ │ ├── callback_extra.go │ │ ├── callback_extra_test.go │ │ ├── doc.go │ │ ├── interface.go │ │ ├── option.go │ │ └── option_test.go │ ├── 
model/ │ │ ├── callback_extra.go │ │ ├── callback_extra_test.go │ │ ├── doc.go │ │ ├── interface.go │ │ ├── option.go │ │ └── option_test.go │ ├── prompt/ │ │ ├── callback_extra.go │ │ ├── callback_extra_test.go │ │ ├── chat_template.go │ │ ├── chat_template_test.go │ │ ├── doc.go │ │ ├── interface.go │ │ ├── option.go │ │ └── option_test.go │ ├── retriever/ │ │ ├── callback_extra.go │ │ ├── callback_extra_test.go │ │ ├── doc.go │ │ ├── interface.go │ │ ├── option.go │ │ └── option_test.go │ ├── tool/ │ │ ├── callback_extra.go │ │ ├── callback_extra_test.go │ │ ├── doc.go │ │ ├── interface.go │ │ ├── interrupt.go │ │ ├── interrupt_test.go │ │ ├── option.go │ │ ├── option_test.go │ │ └── utils/ │ │ ├── common.go │ │ ├── common_test.go │ │ ├── create_options.go │ │ ├── doc.go │ │ ├── error_handler.go │ │ ├── error_handler_test.go │ │ ├── invokable_func.go │ │ ├── invokable_func_test.go │ │ ├── streamable_func.go │ │ └── streamable_func_test.go │ └── types.go ├── compose/ │ ├── branch.go │ ├── branch_test.go │ ├── chain.go │ ├── chain_branch.go │ ├── chain_branch_test.go │ ├── chain_parallel.go │ ├── chain_test.go │ ├── checkpoint.go │ ├── checkpoint_migrate_test.go │ ├── checkpoint_test.go │ ├── component_to_graph_node.go │ ├── dag.go │ ├── dag_test.go │ ├── doc.go │ ├── error.go │ ├── error_test.go │ ├── field_mapping.go │ ├── generic_graph.go │ ├── generic_helper.go │ ├── graph.go │ ├── graph_add_node_options.go │ ├── graph_call_options.go │ ├── graph_call_options_test.go │ ├── graph_compile_options.go │ ├── graph_manager.go │ ├── graph_node.go │ ├── graph_run.go │ ├── graph_test.go │ ├── interrupt.go │ ├── introspect.go │ ├── pregel.go │ ├── resume.go │ ├── resume_test.go │ ├── runnable.go │ ├── runnable_test.go │ ├── state.go │ ├── state_test.go │ ├── stream_concat.go │ ├── stream_concat_test.go │ ├── stream_reader.go │ ├── stream_reader_test.go │ ├── tool_node.go │ ├── tool_node_test.go │ ├── types.go │ ├── types_composable.go │ ├── types_lambda.go │ ├── 
types_lambda_test.go │ ├── utils.go │ ├── utils_test.go │ ├── values_merge.go │ ├── values_merge_test.go │ ├── workflow.go │ └── workflow_test.go ├── doc.go ├── flow/ │ ├── agent/ │ │ ├── agent_option.go │ │ ├── multiagent/ │ │ │ └── host/ │ │ │ ├── callback.go │ │ │ ├── compose.go │ │ │ ├── compose_test.go │ │ │ ├── doc.go │ │ │ ├── options.go │ │ │ └── types.go │ │ ├── react/ │ │ │ ├── callback.go │ │ │ ├── doc.go │ │ │ ├── option.go │ │ │ ├── option_test.go │ │ │ ├── react.go │ │ │ └── react_test.go │ │ └── utils.go │ ├── indexer/ │ │ └── parent/ │ │ ├── parent.go │ │ └── parent_test.go │ └── retriever/ │ ├── multiquery/ │ │ ├── multi_query.go │ │ └── multi_query_test.go │ ├── parent/ │ │ ├── doc.go │ │ ├── parent.go │ │ └── parent_test.go │ ├── router/ │ │ ├── router.go │ │ └── router_test.go │ └── utils/ │ └── utils.go ├── go.mod ├── go.sum ├── internal/ │ ├── callbacks/ │ │ ├── inject.go │ │ ├── interface.go │ │ └── manager.go │ ├── channel.go │ ├── channel_test.go │ ├── concat.go │ ├── concat_test.go │ ├── core/ │ │ ├── address.go │ │ ├── interrupt.go │ │ ├── interrupt_test.go │ │ └── resume.go │ ├── generic/ │ │ ├── generic.go │ │ ├── generic_test.go │ │ ├── type_name.go │ │ └── type_name_test.go │ ├── gmap/ │ │ ├── gmap.go │ │ └── gmap_test.go │ ├── gslice/ │ │ ├── gslice.go │ │ └── gslice_test.go │ ├── merge.go │ ├── mock/ │ │ ├── adk/ │ │ │ └── Agent_mock.go │ │ ├── components/ │ │ │ ├── document/ │ │ │ │ └── document_mock.go │ │ │ ├── embedding/ │ │ │ │ └── Embedding_mock.go │ │ │ ├── indexer/ │ │ │ │ └── indexer_mock.go │ │ │ ├── model/ │ │ │ │ └── ChatModel_mock.go │ │ │ └── retriever/ │ │ │ └── retriever_mock.go │ │ └── doc.go │ ├── safe/ │ │ ├── panic.go │ │ └── panic_test.go │ └── serialization/ │ ├── serialization.go │ └── serialization_test.go ├── llms.txt ├── schema/ │ ├── doc.go │ ├── document.go │ ├── document_test.go │ ├── message.go │ ├── message_parser.go │ ├── message_parser_test.go │ ├── message_test.go │ ├── select.go │ ├── 
serialization.go │ ├── serialization_test.go │ ├── stream.go │ ├── stream_copy_external_test.go │ ├── stream_test.go │ ├── tool.go │ └── tool_test.go ├── scripts/ │ ├── dev_setup.sh │ └── eino_setup.sh └── utils/ └── callbacks/ ├── template.go └── template_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/.codedev.yml ================================================ coverage: status: project: #add everything under here, more options at https://docs.codecov.com/docs/commit-status default: # default is the status check's name, not default settings target: auto #default threshold: 1% #allow coverage to drop by 1% base: auto if_ci_failed: error #success, failure, error, ignore patch: default: target: 82% #default threshold: 1% #allow coverage to drop by 1% base: auto if_ci_failed: error #success, failure, error, ignore comment: #this is a top-level key layout: " diff, flags, files" behavior: default require_changes: false # if true: only post the comment if coverage changes require_base: false # [true :: must have a base report to post] require_head: true # [true :: must have a head report to post] hide_project_coverage: false # [true :: only show coverage on the git diff aka patch coverage] # sample regex patterns ignore: - "tests" - "examples/" - "mock/" - "callbacks/interface.go" - "utils/safe" - "components/tool/utils/create_options.go" ================================================ FILE: .github/.commit-rules.json ================================================ { "allowedTypes": [ "feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert" ], "allowedScopes": [ "adk", "adk/filesystem", "callbacks", "components", "compose", "deep", "dynamictool", "filesystem", "flow", "internal", "middlewares", "planexecute", "plantask", "prebuilt", "reduction", "schema", "skill", "summarization", "supervisor", 
"toolsearch", "utils", "docs", "ci", "serialization" ] } ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' 4. See error **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Version:** Please provide the version of Eino you are using. **Environment:** The output of `go env`. **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ #### What type of PR is this? #### Check the PR title. - [ ] This PR title matches the format: `<type>(optional scope): <description>` - [ ] The description of this PR title is user-oriented and clear enough for others to understand.
- [ ] Attach the PR updating the user documentation if the current PR requires user awareness at the usage level. [User docs repo](https://github.com/cloudwego/cloudwego.github.io) #### (Optional) Translate the PR title into Chinese. #### (Optional) More detailed description for this PR(en: English/zh: Chinese). en: zh(optional): #### (Optional) Which issue(s) this PR fixes: #### (optional) The PR that updates user documentation: ================================================ FILE: .github/workflows/pr-check.yml ================================================ name: Pull Request Check on: [ pull_request ] jobs: compliant: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Check License Header uses: apache/skywalking-eyes/header@v0.4.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Check Spell uses: crate-ci/typos@v1.42.3 golangci-lint: runs-on: ubuntu-latest permissions: contents: write pull-requests: write repository-projects: write steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: 1.18 # for self-hosted, the cache path is shared across projects, # and it works well without the cache of GitHub actions # Enable it if we're going to use GitHub only cache: true - name: Golang CI Lint # https://golangci-lint.run/ uses: golangci/golangci-lint-action@v9.2.0 with: version: v2.8.0 args: --timeout 5m commit-msg-check: name: Commit Message Check runs-on: ubuntu-latest permissions: contents: read pull-requests: read steps: - uses: actions/checkout@v4 - name: Validate commit messages format and scope uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const fs = require('fs'); const pr = context.payload.pull_request; if (!pr) { core.setFailed('This workflow must run on pull_request events.'); return; } let allowedTypes = []; let allowedScopes = []; try { const raw = fs.readFileSync('.github/.commit-rules.json', 'utf8'); const cfg = JSON.parse(raw); allowedTypes = 
Array.isArray(cfg.allowedTypes) ? cfg.allowedTypes : []; allowedScopes = Array.isArray(cfg.allowedScopes) ? cfg.allowedScopes : []; } catch (e) { core.setFailed('Cannot read .github/.commit-rules.json: ' + e.message); return; } if (!allowedTypes.length) { core.setFailed('allowedTypes is empty in .github/.commit-rules.json'); return; } const { owner, repo } = context.repo; const pull_number = pr.number; const commits = await github.paginate( github.rest.pulls.listCommits, { owner, repo, pull_number, per_page: 100 } ); let errors = []; for (const c of commits) { const sha = c.sha.slice(0, 7); const subject = (c.commit.message || '').split('\n')[0]; const m = subject.match(/^([a-z]+)(\(([a-z0-9\-\/]+)\))?:\s(.+)$/); if (!m) { errors.push(`(${sha}) invalid format: "${subject}"`); continue; } const type = m[1]; const scope = m[3]; // may be undefined const desc = m[4]; if (!allowedTypes.includes(type)) { errors.push(`(${sha}) invalid type "${type}"`); } if (!desc || !desc.trim()) { errors.push(`(${sha}) description must be non-empty`); } if (scope) { const topScope = scope.split('/')[0]; if (allowedScopes.length && !allowedScopes.includes(topScope)) { errors.push(`(${sha}) invalid scope "${scope}"`); } } } if (errors.length) { core.setFailed('Commit message check failed:\n' + errors.join('\n')); } else { core.info('All commit messages conform to "(optional scope): " and scope rules.'); } pr-title-check: name: PR Title Check runs-on: ubuntu-latest permissions: contents: read pull-requests: read steps: - uses: actions/checkout@v4 - name: Read commit rules id: rules uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const fs = require('fs'); let cfg; try { const raw = fs.readFileSync('.github/.commit-rules.json', 'utf8'); cfg = JSON.parse(raw); } catch (e) { core.setFailed('Cannot read .github/.commit-rules.json: ' + e.message); return; } const toMultiline = (list) => Array.isArray(list) ? 
list.join('\n') : ''; core.setOutput('types', toMultiline(cfg.allowedTypes)); core.setOutput('scopes', toMultiline(cfg.allowedScopes)); - name: Validate PR title uses: amannn/action-semantic-pull-request@v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: types: ${{ steps.rules.outputs.types }} scopes: ${{ steps.rules.outputs.scopes }} requireScope: false ================================================ FILE: .github/workflows/tag-notification.yml ================================================ name: Tag Notification on: push: tags: - 'v*' jobs: notify: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 fetch-tags: true - name: Fetch tag info run: | git fetch --tags -f - name: Get tag info and send notification run: | # Get the tag name TAG_NAME="${{ github.ref_name }}" echo "Processing tag: $TAG_NAME" # Get tag message echo "Getting tag message..." TAG_MESSAGE=$(git tag -l --format='%(contents)' "$TAG_NAME") echo "Tag message:" echo "$TAG_MESSAGE" echo "---" # Create base content parts HEADER="### 🏷️ Eino New Tag Created: \`$TAG_NAME\`" VERSION_INFO="📦 Version: \`$TAG_NAME\`" # Prepare the message parts for jq if [ ! 
-z "$TAG_MESSAGE" ]; then # Pass all parts to jq and let it handle the formatting jq -n \ --arg header "$HEADER" \ --arg version "$VERSION_INFO" \ --arg notes "$TAG_MESSAGE" \ --arg repo_url "https://github.com/${{ github.repository }}/releases/tag/$TAG_NAME" \ '{ "msg_type": "interactive", "card": { "elements": [ { "tag": "markdown", "content": ($header + "\n\n" + $version + "\n\n### 📝 Release Notes:\n" + $notes) }, { "tag": "action", "actions": [ { "tag": "button", "text": { "tag": "plain_text", "content": "🔗 View Tag" }, "url": $repo_url, "type": "default" } ] } ], "header": { "title": { "tag": "plain_text", "content": "🏷️ Eino New Tag Created" } } } }' > webhook_payload.json else # Without release notes jq -n \ --arg header "$HEADER" \ --arg version "$VERSION_INFO" \ --arg repo_url "https://github.com/${{ github.repository }}/releases/tag/$TAG_NAME" \ '{ "msg_type": "interactive", "card": { "elements": [ { "tag": "markdown", "content": ($header + "\n\n" + $version) }, { "tag": "action", "actions": [ { "tag": "button", "text": { "tag": "plain_text", "content": "🔗 View Tag" }, "url": $repo_url, "type": "default" } ] } ], "header": { "title": { "tag": "plain_text", "content": "🏷️ Eino New Tag Created" } } } }' > webhook_payload.json fi # Send webhook curl -X POST \ -H "Content-Type: application/json" \ -d @webhook_payload.json \ "${{ secrets.FEISHU_WEBHOOK_URL }}" ================================================ FILE: .github/workflows/tests.yml ================================================ name: Eino Tests on: pull_request: push: branches: - main env: DEFAULT_GO_VERSION: "1.18" jobs: unit-test: name: eino-unit-test runs-on: ubuntu-latest permissions: contents: write pull-requests: write repository-projects: write env: COVERAGE_FILE: coverage.out BREAKDOWN_FILE: main.breakdown steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: ${{ env.DEFAULT_GO_VERSION }} - name: Exec Go Test run: | modules=`find . 
-name "go.mod" -exec dirname {} \;` echo $modules list="" coverpkg="" if [[ ! -f "go.work" ]];then go work init;fi for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done go work sync go test -race -v -coverprofile=${{ env.COVERAGE_FILE }} -gcflags="all=-l -N" -coverpkg=$coverpkg $list - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: name: eino-unit-test env_vars: GOLANG,EINO files: ${{ env.COVERAGE_FILE }} token: ${{ secrets.CODECOV_TOKEN }} codecov_yml_path: ./github/.codecov.yml benchmark-test: runs-on: ubuntu-latest permissions: contents: write pull-requests: write repository-projects: write steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: ${{ env.DEFAULT_GO_VERSION }} - name: Run Benchmark Tests run: go test -bench=. -benchmem -run=none ./... compatibility-test: strategy: matrix: go: [ "1.19", "1.20", "1.21", "1.22", "1.23", "1.24" ] runs-on: ubuntu-latest permissions: contents: write pull-requests: write repository-projects: write steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} cache: true - name: Compatibility Test run: | # just basic unit test, no coverage report go test -race ./... api-compatibility: name: api-compatibility-check runs-on: ubuntu-latest permissions: contents: write pull-requests: write repository-projects: write if: github.event_name == 'pull_request' steps: - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v5 with: go-version: "1.22" - name: Install go-apidiff run: go install github.com/joelanford/go-apidiff@v0.8.2 - name: Check API compatibility id: apidiff run: | BASE_SHA=${{ github.event.pull_request.base.sha }} HEAD_SHA=${{ github.event.pull_request.head.sha }} echo "Checking API compatibility between $BASE_SHA and $HEAD_SHA" go mod tidy if ! 
DIFF_OUTPUT=$(go-apidiff $BASE_SHA $HEAD_SHA 2>&1); then echo "go-apidiff output: $DIFF_OUTPUT" fi echo "diff_output<> $GITHUB_ENV echo "$DIFF_OUTPUT" >> $GITHUB_ENV echo "EOF" >> $GITHUB_ENV if echo "$DIFF_OUTPUT" | grep -q "Incompatible changes:"; then echo "has_breaking_changes=true" >> $GITHUB_OUTPUT else echo "has_breaking_changes=false" >> $GITHUB_OUTPUT fi - name: Create Review Thread if: steps.apidiff.outputs.has_breaking_changes == 'true' continue-on-error: true uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const reviewComments = await github.rest.pulls.listReviewComments({ owner: context.repo.owner, repo: context.repo.repo, pull_number: context.issue.number }); const existingPackageComments = new Map(); for (const comment of reviewComments.data) { if (comment.body.includes('Breaking API Changes Detected')) { const packageMatch = comment.body.match(/Package: `([^`]+)`/); if (packageMatch) { const pkg = packageMatch[1]; if (!existingPackageComments.has(pkg)) { existingPackageComments.set(pkg, new Set()); } existingPackageComments.get(pkg).add(comment.path); } } } const files = await github.rest.pulls.listFiles({ owner: context.repo.owner, repo: context.repo.repo, pull_number: context.issue.number }); const diffOutput = process.env.diff_output || ''; const breakingChanges = new Map(); let currentPackage = ''; let isInIncompatibleSection = false; const lines = diffOutput.split('\n'); for (let i = 0; i < lines.length; i++) { const line = lines[i].trim(); if (line.startsWith('github.com/')) { currentPackage = line; if (!breakingChanges.has(currentPackage)) { breakingChanges.set(currentPackage, []); } continue; } if (line === 'Incompatible changes:') { isInIncompatibleSection = true; continue; } if (line === '') { isInIncompatibleSection = false; continue; } if (isInIncompatibleSection && line.startsWith('- ')) { const change = line.substring(2); if (currentPackage) { breakingChanges.get(currentPackage).push(change); } 
} } const changedFiles = files.data; for (const [pkg, changes] of breakingChanges) { if (changes.length === 0) continue; const pkgPath = pkg.split('/').slice(3).join('/'); const matchingFile = changedFiles.find(file => file.filename.includes(pkgPath) ) || changedFiles[0]; const hasCommentForPackage = existingPackageComments.has(pkg) && existingPackageComments.get(pkg).has(matchingFile.filename); if (matchingFile && !hasCommentForPackage) { const changesList = changes.map(change => { const [name, desc] = change.split(':').map(s => s.trim()); return `- **${name}:** ${desc}`; }).join('\n'); const commentBody = [ '🚨 **Breaking API Changes Detected**', '', `Package: \`${pkg}\``, '', 'Incompatible changes:', changesList, '', '
', 'Review Guidelines', '', 'Please ensure that:', '- The changes are absolutely necessary', '- They are properly documented', '- Migration guides are provided if needed', '
', '', '⚠️ Please resolve this thread after reviewing the breaking changes.' ].join('\n'); await github.rest.pulls.createReview({ owner: context.repo.owner, repo: context.repo.repo, pull_number: context.issue.number, event: 'COMMENT', comments: [{ path: matchingFile.filename, position: matchingFile.patch ? matchingFile.patch.split('\n').findIndex(line => line.startsWith('+')) + 1 : 1, body: commentBody }] }); if (!existingPackageComments.has(pkg)) { existingPackageComments.set(pkg, new Set()); } existingPackageComments.get(pkg).add(matchingFile.filename); } } ================================================ FILE: .gitignore ================================================ # Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib # Test binary, built with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out # Dependency directories (remove the comment below to include it) # vendor/ # Go workspace file go.work go.work.sum # env file .env # the result of the go build output* output/* # Files generated by IDEs .idea/ *.iml # Vim swap files *.swp # Vscode files .vscode /patches /vendor # Trae files .trae # Specs files (internal documentation) **/specs/ # Reports (generated analysis files) reports/ .DS_Store *.log CLAUDE.md # Specs directories */specs /todos /.claude/ # Internal dev setup (not for public repo) /scripts/dev_setup_internal.sh ================================================ FILE: .golangci.yaml ================================================ # output configuration options version: "2" # All available settings of specific linters. 
# Refer to https://golangci-lint.run/usage/linters linters: default: standard enable: - revive - godoclint - funlen - cyclop disable: - errcheck - staticcheck - unused - ineffassign exclusions: generated: lax paths: - ".*_test.go" - ".*_mock.go" rules: - path: "^internal/.*" linters: - revive - text: "var-naming: don't use underscores in Go names" linters: - revive - path: "/utils/" text: "var-naming: avoid meaningless package names" linters: - revive - text: "exported: type name will be used as agent.AgentOption by other packages" linters: - revive - path: "adk/prebuilt/deep/task_tool.go" text: "argument-limit: maximum number of arguments per function exceeded" linters: - revive - path: "compose/component_to_graph_node.go" text: "argument-limit: maximum number of arguments per function exceeded" linters: - revive - path: "compose/graph_run.go" text: "argument-limit: maximum number of arguments per function exceeded" linters: - revive - path: "adk/workflow.go" text: "argument-limit: maximum number of arguments per function exceeded" linters: - revive - path: "compose/graph.go" linters: - cyclop text: "calculated cyclomatic complexity for function compile" - path: "schema/message.go" linters: - cyclop text: "calculated cyclomatic complexity for function ConcatMessages" - path: "compose/graph_run.go" linters: - cyclop text: "calculated cyclomatic complexity for function run" - path: "compose/graph.go" linters: - funlen text: "Function 'compile' is too long" - path: "compose/graph_run.go" linters: - funlen text: "Function 'run' is too long" settings: govet: enable-all: true # Disable analyzers by name. # Run `go tool vet help` to see all analyzers. disable: - fieldalignment revive: # Sets the default failure confidence. # This means that linting errors with less than 0.8 confidence will be ignored. # Default: 0.8 confidence: 0.8 rules: # Exported function and methods should have comments. 
- name: exported severity: error exclude: - "^internal/.*" arguments: - "disable-checks-on-constants" - "disable-checks-on-variables" - "disable-checks-on-types" - "disable-checks-on-methods" - name: package-comments disabled: false - name: var-naming disabled: false arguments: # AllowList - [ "utils", "s_", "err_", "err__", "plan_", "userInput_", "executedSteps_", "executedStep_", "iterator_", "in_", "out_" ] # DenyList - [ ] - - extra-bad-package-names: - helpers - models - name: argument-limit arguments: [ 6 ] - name: function-length arguments: [ 120, 0 ] godoclint: check-exported: true require-package-documentation: true funlen: lines: 200 statements: 120 cyclop: max-complexity: 40 package-average: 20 formatters: enable: - gci - gofmt settings: gofmt: # Simplify code: gofmt with `-s` option. # Default: true simplify: true # Apply the rewrite rules to the source before reformatting. # https://pkg.go.dev/cmd/gofmt # Default: [] rewrite-rules: - pattern: 'interface{}' replacement: 'any' - pattern: 'a[b:len(a)]' replacement: 'a[b:]' gci: # Section configuration to compare against. # Section names are case-insensitive and may contain parameters in (). # The default order of sections is `standard > default > custom > blank > dot > alias > localmodule`. # If `custom-order` is `true`, it follows the order of `sections` option. # Default: ["standard", "default"] sections: - standard - default - localmodule custom-order: true ================================================ FILE: .licenserc.yaml ================================================ header: license: spdx-id: Apache-2.0 copyright-owner: CloudWeGo Authors template: | /* * Copyright {{ .Year }} CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ paths: - '**/*.go' - '**/*.s' comment: on-failure ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at conduct@cloudwego.io. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. 
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: CONTRIBUTING.md ================================================ # How to Contribute ## Your First Pull Request We use GitHub for our codebase. You can start by reading [How To Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests). ## Branch Organization We use [git-flow](https://nvie.com/posts/a-successful-git-branching-model/) as our branching model, following a [feature-driven development (FDD)](https://en.wikipedia.org/wiki/Feature-driven_development) style. ## Bugs ### 1. How to Find Known Issues We use [GitHub Issues](https://github.com/cloudwego/eino/issues) to track public bugs. We keep a close eye on them and try to make it clear when we have an internal fix in progress. Before filing a new issue, please make sure your problem has not already been reported. ### 2. Reporting New Issues Providing a minimal, reproducible test case is the recommended way to report an issue. You can place it: - directly in the issue - in the [Go Playground](https://play.golang.org/) ### 3.
Security Bugs Please do not report security bugs in public issues. Instead, contact us privately via [Support Email](mailto:conduct@cloudwego.io). ## How to Get in Touch - [Email](mailto:conduct@cloudwego.io) ## Submit a Pull Request Before you submit your Pull Request (PR), consider the following guidelines: 1. Search [GitHub](https://github.com/cloudwego/eino/pulls) for an open or closed PR that relates to your submission. You don't want to duplicate existing efforts. 2. Be sure that an issue describes the problem you're fixing, or documents the design for the feature you'd like to add. Discussing the design upfront helps to ensure that we're ready to accept your work. 3. [Fork](https://docs.github.com/en/github/getting-started-with-github/fork-a-repo) the cloudwego eino repo. 4. In your forked repository, make your changes in a new git branch: ``` git checkout -b my-fix-branch develop ``` 5. Create your patch, including appropriate test cases. 6. Follow our [Style Guides](#code-style-guides). 7. Commit your changes using a descriptive commit message that follows [AngularJS Git Commit Message Conventions](https://docs.google.com/document/d/1QrDFcIiPjSLDn3EL15IJygNPiHORgU1_OOAqWjiDU5Y/edit). Adherence to these conventions is necessary because release notes are automatically generated from these messages. 8. Push your branch to GitHub: ``` git push origin my-fix-branch ``` 9. In GitHub, send a pull request to `eino:develop` ## Contribution Prerequisites - Our development environment keeps up with [Go Official](https://golang.org/project/). - Before submitting your pull request, make sure your changes pass the lint tools: [gofmt](https://golang.org/pkg/cmd/gofmt/) and [golangci-lint](https://github.com/golangci/golangci-lint). - You should be familiar with [GitHub](https://github.com). - Familiarity with [Actions](https://github.com/features/actions) (our default workflow tool) is also helpful.
## Code Style Guides - [Effective Go](https://golang.org/doc/effective_go) - [Go Code Review Comments](https://github.com/golang/go/wiki/CodeReviewComments) ================================================ FILE: LICENSE-APACHE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
END OF TERMS AND CONDITIONS ================================================ FILE: README.md ================================================ # Eino ![coverage](https://raw.githubusercontent.com/cloudwego/eino/badges/.badges/main/coverage.svg) [![Release](https://img.shields.io/github/v/release/cloudwego/eino)](https://github.com/cloudwego/eino/releases) [![WebSite](https://img.shields.io/website?up_message=cloudwego&url=https%3A%2F%2Fwww.cloudwego.io%2F)](https://www.cloudwego.io/) [![License](https://img.shields.io/github/license/cloudwego/eino)](https://github.com/cloudwego/eino/blob/main/LICENSE-APACHE) [![Go Report Card](https://goreportcard.com/badge/github.com/cloudwego/eino)](https://goreportcard.com/report/github.com/cloudwego/eino) [![OpenIssue](https://img.shields.io/github/issues/cloudwego/eino)](https://github.com/cloudwego/eino/issues) [![ClosedIssue](https://img.shields.io/github/issues-closed/cloudwego/eino)](https://github.com/cloudwego/eino/issues?q=is%3Aissue+is%3Aclosed) ![Stars](https://img.shields.io/github/stars/cloudwego/eino) ![Forks](https://img.shields.io/github/forks/cloudwego/eino) English | [中文](README.zh_CN.md) # Overview **Eino['aino]** is an LLM application development framework in Golang. It draws from LangChain, Google ADK, and other open-source frameworks, and is designed to follow Golang conventions. Eino provides: - **[Components](https://github.com/cloudwego/eino-ext)**: reusable building blocks like `ChatModel`, `Tool`, `Retriever`, and `ChatTemplate`, with official implementations for OpenAI, Ollama, and more. - **Agent Development Kit (ADK)**: build AI agents with tool use, multi-agent coordination, context management, interrupt/resume for human-in-the-loop, and ready-to-use agent patterns. - **Composition**: connect components into graphs and workflows that can run standalone or be exposed as tools for agents. - **[Examples](https://github.com/cloudwego/eino-examples)**: working code for common patterns and real-world use cases.
![](.github/static/img/eino/eino_concept.jpeg) # Quick Start ## ChatModelAgent Configure a ChatModel, optionally add tools, and you have a working agent: ```Go chatModel, _ := openai.NewChatModel(ctx, &openai.ChatModelConfig{ Model: "gpt-4o", APIKey: os.Getenv("OPENAI_API_KEY"), }) agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Model: chatModel, }) runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) iter := runner.Query(ctx, "Hello, who are you?") for { event, ok := iter.Next() if !ok { break } fmt.Println(event.Message.Content) } ``` Add tools to give the agent capabilities: ```Go agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Model: chatModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{weatherTool, calculatorTool}, }, }, }) ``` The agent handles the ReAct loop internally — it decides when to call tools and when to respond. → [ChatModelAgent examples](https://github.com/cloudwego/eino-examples/tree/main/adk/intro) · [docs](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_implementation/chat_model/) ## DeepAgent For complex tasks, use DeepAgent. It breaks down problems into steps, delegates to sub-agents, and tracks progress: ```Go deepAgent, _ := deep.New(ctx, &deep.Config{ ChatModel: chatModel, SubAgents: []adk.Agent{researchAgent, codeAgent}, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{shellTool, pythonTool, webSearchTool}, }, }, }) runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: deepAgent}) iter := runner.Query(ctx, "Analyze the sales data in report.csv and generate a summary chart") ``` DeepAgent can be configured to coordinate multiple specialized agents, run shell commands, execute Python code, and search the web. 
→ [DeepAgent example](https://github.com/cloudwego/eino-examples/tree/main/adk/multiagent/deep) · [docs](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_implementation/deepagents/) ## Composition When you need precise control over execution flow, use `compose` to build graphs and workflows: ```Go graph := compose.NewGraph[*Input, *Output]() graph.AddLambdaNode("validate", validateFn) graph.AddChatModelNode("generate", chatModel) graph.AddLambdaNode("format", formatFn) graph.AddEdge(compose.START, "validate") graph.AddEdge("validate", "generate") graph.AddEdge("generate", "format") graph.AddEdge("format", compose.END) runnable, _ := graph.Compile(ctx) result, _ := runnable.Invoke(ctx, input) ``` Compositions can be exposed as tools for agents, bridging deterministic workflows with autonomous behavior: ```Go pipelineTool, _ := graphtool.NewInvokableGraphTool(graph, "data_pipeline", "Process and validate data") agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Model: chatModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{pipelineTool}, }, }, }) ``` This lets you build domain-specific pipelines with exact control, then let agents decide when to use them. → [GraphTool examples](https://github.com/cloudwego/eino-examples/tree/main/adk/common/tool/graphtool) · [compose docs](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/) # Key Features ## Component Ecosystem Eino defines component abstractions (ChatModel, Tool, Retriever, Embedding, etc.) with official implementations for OpenAI, Claude, Gemini, Ark, Ollama, Elasticsearch, and more. → [eino-ext](https://github.com/cloudwego/eino-ext) ## Stream Processing Eino automatically handles streaming throughout orchestration: concatenating, boxing, merging, and copying streams as data flows between nodes. Components only implement the streaming paradigms that make sense for them; the framework handles the rest.
→ [docs](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/stream_programming_essentials/) ## Callback Aspects Inject logging, tracing, and metrics at fixed points (OnStart, OnEnd, OnError, OnStartWithStreamInput, OnEndWithStreamOutput) across components, graphs, and agents. → [docs](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/callback_manual/) ## Interrupt/Resume Any agent or tool can pause execution for human input and resume from checkpoint. The framework handles state persistence and routing. → [docs](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_hitl/) · [examples](https://github.com/cloudwego/eino-examples/tree/main/adk/human-in-the-loop) # Framework Structure ![](.github/static/img/eino/eino_framework.jpeg) The Eino framework consists of: - Eino (this repo): Type definitions, streaming mechanism, component abstractions, orchestration, agent implementations, aspect mechanisms - [EinoExt](https://github.com/cloudwego/eino-ext): Component implementations, callback handlers, usage examples, evaluators, prompt optimizers - [Eino Devops](https://github.com/cloudwego/eino-ext/tree/main/devops): Visualized development and debugging - [EinoExamples](https://github.com/cloudwego/eino-examples): Example applications and best practices ## Documentation - [Eino User Manual](https://www.cloudwego.io/zh/docs/eino/) - [Eino: Quick Start](https://www.cloudwego.io/zh/docs/eino/quick_start/) ## Dependencies - Go 1.18 and above. ## Code Style This repo uses `golangci-lint`. Check locally with: ```bash golangci-lint run ./... ``` Rules enforced: - Exported functions, interfaces, packages, etc. 
should have GoDoc comments - Code should be formatted with `gofmt -s` - Import order should follow `goimports` (std -> third party -> local) ## Security If you discover a potential security issue, notify ByteDance Security via the [security center](https://security.bytedance.com/src) or [vulnerability reporting email](mailto:sec@bytedance.com). Do **not** create a public GitHub issue. ## Contact - Membership: [COMMUNITY MEMBERSHIP](https://github.com/cloudwego/community/blob/main/COMMUNITY_MEMBERSHIP.md) - Issues: [Issues](https://github.com/cloudwego/eino/issues) - Lark: Scan the QR code below with [Feishu](https://www.feishu.cn/en/) to join the CloudWeGo/eino user group.     LarkGroup ## License This project is licensed under the [Apache-2.0 License](LICENSE-APACHE). ================================================ FILE: README.zh_CN.md ================================================ # Eino ![coverage](https://raw.githubusercontent.com/cloudwego/eino/badges/.badges/main/coverage.svg) [![Release](https://img.shields.io/github/v/release/cloudwego/eino)](https://github.com/cloudwego/eino/releases) [![WebSite](https://img.shields.io/website?up_message=cloudwego&url=https%3A%2F%2Fwww.cloudwego.io%2F)](https://www.cloudwego.io/) [![License](https://img.shields.io/github/license/cloudwego/eino)](https://github.com/cloudwego/eino/blob/main/LICENSE-APACHE) [![Go Report Card](https://goreportcard.com/badge/github.com/cloudwego/eino)](https://goreportcard.com/report/github.com/cloudwego/eino) [![OpenIssue](https://img.shields.io/github/issues/cloudwego/eino)](https://github.com/cloudwego/eino/issues) [![ClosedIssue](https://img.shields.io/github/issues-closed/cloudwego/eino)](https://github.com/cloudwego/eino/issues?q=is%3Aissue+is%3Aclosed) ![Stars](https://img.shields.io/github/stars/cloudwego/eino) ![Forks](https://img.shields.io/github/forks/cloudwego/eino) [English](README.md) | 中文 # 简介 **Eino['aino]** 是一个 Go 语言的 LLM 应用开发框架,借鉴了 LangChain、Google ADK 等开源项目,按照 Go 的惯例设计。 Eino 提供: -
**[组件](https://github.com/cloudwego/eino-ext)**:`ChatModel`、`Tool`、`Retriever`、`ChatTemplate` 等可复用模块,官方实现覆盖 OpenAI、Ollama 等 - **智能体开发套件(ADK)**:支持工具调用、多智能体协同、上下文管理、中断/恢复等人机交互,以及开箱即用的智能体模式 - **编排**:把组件组装成图或工作流,既能独立运行,也能作为工具给智能体调用 - **[示例](https://github.com/cloudwego/eino-examples)**:常见模式和实际场景的可运行代码 ![](.github/static/img/eino/eino_concept.jpeg) # 快速上手 ## ChatModelAgent 配置好 ChatModel,加上工具(可选),就能跑起来: ```Go chatModel, _ := openai.NewChatModel(ctx, &openai.ChatModelConfig{ Model: "gpt-4o", APIKey: os.Getenv("OPENAI_API_KEY"), }) agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Model: chatModel, }) runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) iter := runner.Query(ctx, "Hello, who are you?") for { event, ok := iter.Next() if !ok { break } fmt.Println(event.Message.Content) } ``` 加工具让智能体有更多能力: ```Go agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Model: chatModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{weatherTool, calculatorTool}, }, }, }) ``` 智能体内部自动处理 ReAct 循环,自己判断什么时候调工具、什么时候回复。 → [ChatModelAgent 示例](https://github.com/cloudwego/eino-examples/tree/main/adk/intro) · [文档](https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/agent_implementation/chat_model/) ## DeepAgent 复杂任务用 DeepAgent,它会把问题拆成步骤,分派给子智能体,并追踪进度: ```Go deepAgent, _ := deep.New(ctx, &deep.Config{ ChatModel: chatModel, SubAgents: []adk.Agent{researchAgent, codeAgent}, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{shellTool, pythonTool, webSearchTool}, }, }, }) runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: deepAgent}) iter := runner.Query(ctx, "Analyze the sales data in report.csv and generate a summary chart") ``` DeepAgent 可以配置成:协调多个专业智能体、跑 shell 命令、执行 Python、搜索网络。 → [DeepAgent 示例](https://github.com/cloudwego/eino-examples/tree/main/adk/multiagent/deep) · 
[文档](https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/agent_implementation/deepagents/) ## 编排 需要精确控制执行流程时,用 `compose` 搭图或工作流: ```Go graph := compose.NewGraph[*Input, *Output]() graph.AddLambdaNode("validate", validateFn) graph.AddChatModelNode("generate", chatModel) graph.AddLambdaNode("format", formatFn) graph.AddEdge(compose.START, "validate") graph.AddEdge("validate", "generate") graph.AddEdge("generate", "format") graph.AddEdge("format", compose.END) runnable, _ := graph.Compile(ctx) result, _ := runnable.Invoke(ctx, input) ``` 编排出来的流程可以包装成工具给智能体用,把确定性流程和自主决策结合起来: ```Go pipelineTool, _ := graphtool.NewInvokableGraphTool(graph, "data_pipeline", "Process and validate data") agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Model: chatModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{pipelineTool}, }, }, }) ``` 这样你可以写出精确可控的业务流程,再让智能体决定什么时候调用。 → [GraphTool 示例](https://github.com/cloudwego/eino-examples/tree/main/adk/common/tool/graphtool) · [编排文档](https://www.cloudwego.io/zh/docs/eino/core_modules/chain_and_graph_orchestration/) # 主要特性 ## 组件生态 Eino 定义了组件抽象(ChatModel、Tool、Retriever、Embedding 等),官方实现覆盖 OpenAI、Claude、Gemini、Ark、Ollama、Elasticsearch 等。 → [eino-ext](https://github.com/cloudwego/eino-ext) ## 流式处理 Eino 在编排中自动处理流式:拼接、装箱、合并、复制。组件只需实现有业务意义的流式范式,框架处理剩下的。 → [文档](https://www.cloudwego.io/zh/docs/eino/core_modules/chain_and_graph_orchestration/stream_programming_essentials/) ## 回调切面 在固定切点(OnStart、OnEnd、OnError、OnStartWithStreamInput、OnEndWithStreamOutput)注入日志、追踪、指标,适用于组件、图、智能体。 → [文档](https://www.cloudwego.io/zh/docs/eino/core_modules/chain_and_graph_orchestration/callback_manual/) ## 中断/恢复 任何智能体或工具都能暂停等待人工输入,从检查点恢复。框架处理状态持久化和路由。 → [文档](https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/agent_hitl/) · [示例](https://github.com/cloudwego/eino-examples/tree/main/adk/human-in-the-loop) # 框架结构 ![](.github/static/img/eino/eino_framework.jpeg) Eino 框架包含: - Eino(本仓库):类型定义、流处理机制、组件抽象、编排、智能体实现、切面机制 -
[EinoExt](https://github.com/cloudwego/eino-ext):组件实现、回调处理器、使用示例、评估器、提示优化器 - [Eino Devops](https://github.com/cloudwego/eino-ext/tree/main/devops):可视化开发和调试 - [EinoExamples](https://github.com/cloudwego/eino-examples):示例应用和最佳实践 ## 文档 - [Eino 用户手册](https://www.cloudwego.io/zh/docs/eino/) - [Eino: 快速开始](https://www.cloudwego.io/zh/docs/eino/quick_start/) ## 依赖 - Go 1.18 及以上 ## 代码规范 本仓库使用 `golangci-lint`,本地检查: ```bash golangci-lint run ./... ``` 规则: - 导出的函数、接口、package 等需要 GoDoc 注释 - 代码格式符合 `gofmt -s` - import 顺序符合 `goimports`(std -> third party -> local) ## 安全 发现安全问题请通过[安全中心](https://security.bytedance.com/src)或[漏洞报告邮箱](mailto:sec@bytedance.com)联系字节跳动安全团队。 请**不要**创建公开的 GitHub Issue。 ## 联系我们 - 成为 member:[COMMUNITY MEMBERSHIP](https://github.com/cloudwego/community/blob/main/COMMUNITY_MEMBERSHIP.md) - Issues:[Issues](https://github.com/cloudwego/eino/issues) - 飞书:扫码加入 CloudWeGo/eino 用户群     LarkGroup ## 开源许可证 本项目基于 [Apache-2.0 许可证](LICENSE-APACHE) 开源。 ================================================ FILE: _typos.toml ================================================ # Typo check: https://github.com/crate-ci/typos [default] [default.extend-words] Invokable = "Invokable" invokable = "invokable" InvokableLambda = "InvokableLambda" InvokableRun = "InvokableRun" typ = "typ" byted = "byted" cpy = "cpy" mak = "mak" [files] extend-exclude = ["go.mod", "go.sum", "check_branch_name.sh"] ================================================ FILE: adk/agent_tool.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. */ // Package adk provides core agent development kit utilities and types. package adk import ( "context" "errors" "fmt" "github.com/bytedance/sonic" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) var ( // defaultAgentToolParam is the tool schema used when no custom input schema is configured: a single required string parameter named "request". defaultAgentToolParam = schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "request": { Desc: "request to be processed", Required: true, Type: schema.String, }, }) ) // AgentToolOptions holds the settings collected from the AgentToolOption values passed to NewAgentTool. type AgentToolOptions struct { fullChatHistoryAsInput bool agentInputSchema *schema.ParamsOneOf } // AgentToolOption configures an agent tool created by NewAgentTool. type AgentToolOption func(*AgentToolOptions) // WithFullChatHistoryAsInput enables using the full chat history as input. // When set, the tool-call arguments are ignored: the wrapped agent receives the calling agent's // chat history (without system messages and without the pending tool-call message) followed by // synthetic transfer messages. func WithFullChatHistoryAsInput() AgentToolOption { return func(options *AgentToolOptions) { options.fullChatHistoryAsInput = true } } // WithAgentInputSchema sets a custom input schema for the agent tool. // With a custom schema the raw JSON arguments are passed to the wrapped agent as a single user // message, instead of extracting the "request" field of the default schema. func WithAgentInputSchema(inputSchema *schema.ParamsOneOf) AgentToolOption { return func(options *AgentToolOptions) { options.agentInputSchema = inputSchema } } // withAgentToolEnableStreaming returns a tool option that sets enableStreaming, which InvokableRun passes to the inner Runner. 
func withAgentToolEnableStreaming(enabled bool) tool.Option { return tool.WrapImplSpecificOptFn(func(opt *agentToolOptions) { opt.enableStreaming = enabled }) } // NewAgentTool creates a tool that wraps an agent for invocation. // // Event Streaming: // When EmitInternalEvents is enabled in ToolsConfig, the agent tool will emit AgentEvent // from the inner agent to the parent agent's AsyncGenerator, allowing real-time streaming // of the inner agent's output to the end-user via Runner. // // Note that these forwarded events are NOT recorded in the parent agent's runSession. // They are only emitted to the end-user and have no effect on the parent agent's state // or checkpoint. The only exception is Interrupted action, which is propagated via // CompositeInterrupt to enable proper interrupt/resume across agent boundaries.
// // Action Scoping: // Actions emitted by the inner agent are scoped to the agent tool boundary: // - Interrupted: Propagated via CompositeInterrupt to allow proper interrupt/resume across boundaries // - Exit, TransferToAgent, BreakLoop: Ignored outside the agent tool; these actions only affect // the inner agent's execution and do not propagate to the parent agent // // This scoping ensures that nested agents cannot unexpectedly terminate or transfer control // of their parent agent's execution flow. func NewAgentTool(_ context.Context, agent Agent, options ...AgentToolOption) tool.BaseTool { opts := &AgentToolOptions{} for _, opt := range options { opt(opts) } return &agentTool{ agent: agent, fullChatHistoryAsInput: opts.fullChatHistoryAsInput, inputSchema: opts.agentInputSchema, } } // agentTool adapts an Agent to the tool interfaces; see NewAgentTool. type agentTool struct { agent Agent fullChatHistoryAsInput bool inputSchema *schema.ParamsOneOf } // Info describes the tool using the wrapped agent's name and description, and the custom input // schema if one was configured (defaultAgentToolParam otherwise). func (at *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) { param := at.inputSchema if param == nil { param = defaultAgentToolParam } return &schema.ToolInfo{ Name: at.agent.Name(ctx), Desc: at.agent.Description(ctx), ParamsOneOf: param, }, nil } // InvokableRun runs the wrapped agent and returns the content of its last event's message. // // On a fresh call the agent input is built from the chat history (WithFullChatHistoryAsInput), // the raw JSON arguments (custom input schema), or the "request" field (default schema). On a // resumed call the inner run is resumed from the stored interrupt state. When the tool options // carry an event generator, non-interrupt inner events are forwarded to it with the parent's run // path prepended. If the inner agent ends interrupted, the interrupt is returned as a // tool.CompositeInterrupt that carries the inner checkpoint data. 
func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { gen, enableStreaming := getEmitGeneratorAndEnableStreaming(opts) var ms *bridgeStore var iter *AsyncIterator[*AgentEvent] var err error wasInterrupted, hasState, state := tool.GetInterruptState[[]byte](ctx) if !wasInterrupted { ms = newBridgeStore() var input []Message if at.fullChatHistoryAsInput { input, err = getReactChatHistory(ctx, at.agent.Name(ctx)) if err != nil { return "", err } } else { if at.inputSchema == nil { // default input schema type request struct { Request string `json:"request"` } req := &request{} err = sonic.UnmarshalString(argumentsInJSON, req) if err != nil { return "", err } argumentsInJSON = req.Request } input = []Message{ schema.UserMessage(argumentsInJSON), } } iter =
newInvokableAgentToolRunner(at.agent, ms, enableStreaming).Run(ctx, input, append(getOptionsByAgentName(at.agent.Name(ctx), opts), WithCheckPointID(bridgeCheckpointID), withSharedParentSession())...) } else { if !hasState { return "", fmt.Errorf("agent tool '%s' interrupt has happened, but cannot find interrupt state", at.agent.Name(ctx)) } ms = newResumeBridgeStore(state) iter, err = newInvokableAgentToolRunner(at.agent, ms, enableStreaming). Resume(ctx, bridgeCheckpointID, append(getOptionsByAgentName(at.agent.Name(ctx), opts), withSharedParentSession())...) if err != nil { return "", err } } var lastEvent *AgentEvent for { event, ok := iter.Next() if !ok { break } // Only the last event's message becomes the result, so release the previous event's stream. if lastEvent != nil && lastEvent.Output != nil && lastEvent.Output.MessageOutput != nil && lastEvent.Output.MessageOutput.MessageStream != nil { lastEvent.Output.MessageOutput.MessageStream.Close() } if event.Err != nil { return "", event.Err } if gen != nil { if event.Action == nil || event.Action.Interrupted == nil { if parentRunCtx := getRunCtx(ctx); parentRunCtx != nil && len(parentRunCtx.RunPath) > 0 { rp := make([]RunStep, 0, len(parentRunCtx.RunPath)+len(event.RunPath)) rp = append(rp, parentRunCtx.RunPath...) rp = append(rp, event.RunPath...)
event.RunPath = rp } // Forward the event and keep a copy locally; presumably so the receiver and this loop do not share one event (e.g. its message stream) — confirm against copyAgentEvent. tmp := copyAgentEvent(event) gen.Send(event) event = tmp } } lastEvent = event } if lastEvent != nil && lastEvent.Action != nil && lastEvent.Action.Interrupted != nil { data, existed, getErr := ms.Get(ctx, bridgeCheckpointID) if getErr != nil { return "", fmt.Errorf("failed to get interrupt info: %w", getErr) } if !existed { return "", fmt.Errorf("interrupt has happened, but cannot find interrupt info") } return "", tool.CompositeInterrupt(ctx, "agent tool interrupt", data, lastEvent.Action.internalInterrupted) } if lastEvent == nil { return "", errors.New("no event returned") } var ret string if lastEvent.Output != nil { if output := lastEvent.Output.MessageOutput; output != nil { msg, err := output.GetMessage() if err != nil { return "", err } ret = msg.Content } } return ret, nil } // agentToolOptions is a wrapper structure used to convert AgentRunOption slices to tool.Option. // It stores the agent name and corresponding run options for tool-specific processing. type agentToolOptions struct { agentName string opts []AgentRunOption generator *AsyncGenerator[*AgentEvent] enableStreaming bool } // withAgentToolOptions attaches AgentRunOptions that getOptionsByAgentName later hands to the agent named agentName. func withAgentToolOptions(agentName string, opts []AgentRunOption) tool.Option { return tool.WrapImplSpecificOptFn(func(opt *agentToolOptions) { opt.agentName = agentName opt.opts = opts }) } // withAgentToolEventGenerator attaches the parent generator that InvokableRun forwards inner events to. func withAgentToolEventGenerator(gen *AsyncGenerator[*AgentEvent]) tool.Option { return tool.WrapImplSpecificOptFn(func(o *agentToolOptions) { o.generator = gen }) } // getOptionsByAgentName collects, in order, the AgentRunOptions attached via withAgentToolOptions for agentName. 
func getOptionsByAgentName(agentName string, opts []tool.Option) []AgentRunOption { var ret []AgentRunOption for _, opt := range opts { o := tool.GetImplSpecificOptions[agentToolOptions](nil, opt) if o != nil && o.agentName == agentName { ret = append(ret, o.opts...) } } return ret } // getEmitGeneratorAndEnableStreaming extracts the parent event generator and streaming flag from opts; it returns (nil, false) when no implementation-specific options are found. func getEmitGeneratorAndEnableStreaming(opts []tool.Option) (*AsyncGenerator[*AgentEvent], bool) { o := tool.GetImplSpecificOptions[agentToolOptions](nil, opts...)
if o == nil { return nil, false } return o.generator, o.enableStreaming } // getReactChatHistory builds the input for an agent tool created with WithFullChatHistoryAsInput. // It copies the calling agent's state messages except the last one (the assistant message carrying // the pending tool call), appends synthetic transfer messages for destAgentName, drops system // messages, and rewrites assistant/tool messages via rewriteMessage using the calling agent's name // taken from the end of the run path. func getReactChatHistory(ctx context.Context, destAgentName string) ([]Message, error) { var messages []Message err := compose.ProcessState(ctx, func(ctx context.Context, st *State) error { // Drop the last message, which is the assistant message carrying the pending tool call. // An empty history has nothing to drop; guarding avoids a makeslice panic on a length of -1. if n := len(st.Messages); n > 0 { messages = make([]Message, n-1) copy(messages, st.Messages[:n-1]) } return nil }) if err != nil { return nil, fmt.Errorf("failed to get chat history from state: %w", err) } var agentName string if runCtx := getRunCtx(ctx); runCtx != nil && len(runCtx.RunPath) > 0 { agentName = runCtx.RunPath[len(runCtx.RunPath)-1].agentName } a, t := GenTransferMessages(ctx, destAgentName) messages = append(messages, a, t) history := make([]Message, 0, len(messages)) for _, msg := range messages { if msg.Role == schema.System { continue } if msg.Role == schema.Assistant || msg.Role == schema.Tool { msg = rewriteMessage(msg, agentName) } history = append(history, msg) } return history, nil } // newInvokableAgentToolRunner builds a Runner for the wrapped agent that checkpoints into store. func newInvokableAgentToolRunner(agent Agent, store compose.CheckPointStore, enableStreaming bool) *Runner { return &Runner{ a: agent, enableStreaming: enableStreaming, store: store, } } ================================================ FILE: adk/agent_tool_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ package adk import ( "context" "fmt" "strings" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) // mockAgentForTool implements the Agent interface for testing: Run replays responses in order and stops after the first event with an Exit action. type mockAgentForTool struct { name string description string responses []*AgentEvent } func (a *mockAgentForTool) Name(_ context.Context) string { return a.name } func (a *mockAgentForTool) Description(_ context.Context) string { return a.description } func (a *mockAgentForTool) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go func() { defer generator.Close() for _, event := range a.responses { generator.Send(event) // If the event has an Exit action, stop sending events if event.Action != nil && event.Action.Exit { break } } }() return iterator } // newMockAgentForTool returns a mockAgentForTool that replays responses. func newMockAgentForTool(name, description string, responses []*AgentEvent) *mockAgentForTool { return &mockAgentForTool{ name: name, description: description, responses: responses, } } func TestAgentTool_Info(t *testing.T) { // Create a mock agent mockAgent_ := newMockAgentForTool("TestAgent", "Test agent description", nil) // Create an agentTool with the mock agent agentTool_ := NewAgentTool(context.Background(), mockAgent_) // Test the Info method ctx := context.Background() info, err := agentTool_.Info(ctx) // Verify results assert.NoError(t, err) assert.NotNil(t, info) assert.Equal(t, "TestAgent", info.Name) assert.Equal(t, "Test agent description", info.Desc) assert.NotNil(t, info.ParamsOneOf) } // TestAgentTool_SharedParentSessionValues checks that the inner agent runs with its own runSession that shares the parent's Values map and valuesMtx, so session values written on either side are visible to the other. func TestAgentTool_SharedParentSessionValues(t *testing.T) { ctx := context.Background() inner := &sessionValuesAgent{name: "inner"} innerTool := NewAgentTool(ctx, inner).(tool.InvokableTool) input := &AgentInput{Messages: []Message{schema.UserMessage("q")}} ctx, _ = initRunCtx(ctx, "outer", input)
AddSessionValue(ctx, "parent_key", "parent_val") parentSession := getRunCtx(ctx).Session _, err := innerTool.InvokableRun(ctx, `{"request":"hello"}`) assert.NoError(t, err) assert.Equal(t, "parent_val", inner.seenParentValue) assert.NotNil(t, inner.capturedSession) assert.NotSame(t, parentSession, inner.capturedSession) assert.NotNil(t, parentSession.valuesMtx) assert.Same(t, parentSession.valuesMtx, inner.capturedSession.valuesMtx) mtx := parentSession.valuesMtx mtx.Lock() inner.capturedSession.Values["direct_child_key"] = "direct_child_val" mtx.Unlock() mtx.Lock() v2, ok2 := parentSession.Values["direct_child_key"] mtx.Unlock() assert.True(t, ok2) assert.Equal(t, "direct_child_val", v2) mtx.Lock() parentSession.Values["direct_parent_key"] = "direct_parent_val" mtx.Unlock() mtx.Lock() v3, ok3 := inner.capturedSession.Values["direct_parent_key"] mtx.Unlock() assert.True(t, ok3) assert.Equal(t, "direct_parent_val", v3) v, ok := GetSessionValue(ctx, "child_key") assert.True(t, ok) assert.Equal(t, "child_val", v) } // sessionValuesAgent records the run session and the "parent_key" session value it sees, writes "child_key"="child_val" into the session, and replies with a single "ok" message. type sessionValuesAgent struct { name string seenParentValue any capturedSession *runSession } func (a *sessionValuesAgent) Name(context.Context) string { return a.name } func (a *sessionValuesAgent) Description(context.Context) string { return "test" } func (a *sessionValuesAgent) Run(ctx context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { if rc := getRunCtx(ctx); rc != nil { a.capturedSession = rc.Session } a.seenParentValue, _ = GetSessionValue(ctx, "parent_key") AddSessionValue(ctx, "child_key", "child_val") it, gen := NewAsyncIteratorPair[*AgentEvent]() gen.Send(&AgentEvent{ AgentName: a.name, Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("ok", nil), Role: schema.Assistant, }, }, }) gen.Close() return it } // TestAgentTool_InvokableRun is a table test of InvokableRun over plain responses, malformed arguments, an empty event stream, and an error event. func TestAgentTool_InvokableRun(t *testing.T) { // Create a context ctx := context.Background() // Test cases tests := []struct { name string agentResponses
[]*AgentEvent request string expectedOutput string expectError bool }{ { name: "successful model response", agentResponses: []*AgentEvent{ { AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Test response", nil), Role: schema.Assistant, }, }, }, }, request: `{"request":"Test request"}`, expectedOutput: "Test response", expectError: false, }, { name: "successful tool call response", agentResponses: []*AgentEvent{ { AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.ToolMessage("Tool response", "test-id"), Role: schema.Tool, }, }, }, }, request: `{"request":"Test tool request"}`, expectedOutput: "Tool response", expectError: false, }, { name: "invalid request JSON", agentResponses: nil, request: `invalid json`, expectedOutput: "", expectError: true, }, { name: "no events returned", agentResponses: []*AgentEvent{}, request: `{"request":"Test request"}`, expectedOutput: "", expectError: true, }, { name: "error in event", agentResponses: []*AgentEvent{ { AgentName: "TestAgent", Err: assert.AnError, }, }, request: `{"request":"Test request"}`, expectedOutput: "", expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Create a mock agent with the test responses mockAgent_ := newMockAgentForTool("TestAgent", "Test agent description", tt.agentResponses) // Create an agentTool with the mock agent agentTool_ := NewAgentTool(ctx, mockAgent_) // Call InvokableRun output, err := agentTool_.(tool.InvokableTool).InvokableRun(ctx, tt.request) // Verify results if tt.expectError { assert.Error(t, err) } else { assert.NoError(t, err) assert.Equal(t, tt.expectedOutput, output) } }) } } // TestGetReactHistory checks that getReactChatHistory drops the trailing tool-call message, rewrites assistant/tool messages into "For context" user messages, and appends the transfer messages. 
func TestGetReactHistory(t *testing.T) { g := compose.NewGraph[string, []Message](compose.WithGenLocalState(func(ctx context.Context) (state *State) { return &State{ Messages: []Message{ schema.UserMessage("user query"),
schema.AssistantMessage("", []schema.ToolCall{{ID: "tool call id 1", Function: schema.FunctionCall{Name: "tool1", Arguments: "arguments1"}}}), schema.ToolMessage("tool result 1", "tool call id 1", schema.WithToolName("tool1")), schema.AssistantMessage("", []schema.ToolCall{{ID: "tool call id 2", Function: schema.FunctionCall{Name: "tool2", Arguments: "arguments2"}}}), }, } })) assert.NoError(t, g.AddLambdaNode("1", compose.InvokableLambda(func(ctx context.Context, input string) (output []Message, err error) { return getReactChatHistory(ctx, "DestAgentName") }))) assert.NoError(t, g.AddEdge(compose.START, "1")) assert.NoError(t, g.AddEdge("1", compose.END)) ctx := context.Background() ctx, _ = initRunCtx(ctx, "MyAgent", nil) runner, err := g.Compile(ctx) assert.NoError(t, err) result, err := runner.Invoke(ctx, "") assert.NoError(t, err) assert.Equal(t, []Message{ schema.UserMessage("user query"), schema.UserMessage("For context: [MyAgent] called tool: `tool1` with arguments: arguments1."), schema.UserMessage("For context: [MyAgent] `tool1` tool returned result: tool result 1."), schema.UserMessage("For context: [MyAgent] called tool: `transfer_to_agent` with arguments: DestAgentName."), schema.UserMessage("For context: [MyAgent] `transfer_to_agent` tool returned result: successfully transferred to agent [DestAgentName]."), }, result) } // mockAgentWithInputCapture implements the Agent interface for testing and captures the input it receives type mockAgentWithInputCapture struct { name string description string capturedInput []Message responses []*AgentEvent } func (a *mockAgentWithInputCapture) Name(_ context.Context) string { return a.name } func (a *mockAgentWithInputCapture) Description(_ context.Context) string { return a.description } func (a *mockAgentWithInputCapture) Run(_ context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { a.capturedInput = input.Messages iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go func()
{ defer generator.Close() for _, event := range a.responses { generator.Send(event) // If the event has an Exit action, stop sending events if event.Action != nil && event.Action.Exit { break } } }() return iterator } func newMockAgentWithInputCapture(name, description string, responses []*AgentEvent) *mockAgentWithInputCapture { return &mockAgentWithInputCapture{ name: name, description: description, responses: responses, } } // TestAgentToolWithOptions covers WithFullChatHistoryAsInput, WithAgentInputSchema (simple and multi-field), and both options combined. func TestAgentToolWithOptions(t *testing.T) { // Test Case 1: WithFullChatHistoryAsInput t.Run("WithFullChatHistoryAsInput", func(t *testing.T) { ctx := context.Background() // 1. Set up a mock agent that will capture the input it receives mockAgent := newMockAgentWithInputCapture("test-agent", "a test agent", []*AgentEvent{ { AgentName: "test-agent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("done", nil), Role: schema.Assistant, }, }, }, }) // 2. Create an agentTool with the option agentTool := NewAgentTool(ctx, mockAgent, WithFullChatHistoryAsInput()) // 3. Set up a context with a chat history using a graph history := []Message{ schema.UserMessage("first user message"), schema.AssistantMessage("first assistant response", nil), } g := compose.NewGraph[string, string](compose.WithGenLocalState(func(ctx context.Context) (state *State) { return &State{ Messages: append(history, schema.AssistantMessage("tool call msg", nil)), } })) assert.NoError(t, g.AddLambdaNode("1", compose.InvokableLambda(func(ctx context.Context, input string) (output string, err error) { // Run the tool within the graph context that has the state _, err = agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"request":"some ignored input"}`) return "done", err }))) assert.NoError(t, g.AddEdge(compose.START, "1")) assert.NoError(t, g.AddEdge("1", compose.END)) ctx, _ = initRunCtx(ctx, "react-agent", nil) runner, err := g.Compile(ctx) assert.NoError(t, err) // 4.
Run the graph which will execute the tool with the state _, err = runner.Invoke(ctx, "") assert.NoError(t, err) // 5. Assert that the agent received the full history // The agent should receive: history (minus last assistant message) + transfer messages assert.Len(t, mockAgent.capturedInput, 4) // 2 from history + 2 transfer messages assert.Equal(t, "first user message", mockAgent.capturedInput[0].Content) assert.Equal(t, "For context: [react-agent] said: first assistant response.", mockAgent.capturedInput[1].Content) assert.Equal(t, "For context: [react-agent] called tool: `transfer_to_agent` with arguments: test-agent.", mockAgent.capturedInput[2].Content) assert.Equal(t, "For context: [react-agent] `transfer_to_agent` tool returned result: successfully transferred to agent [test-agent].", mockAgent.capturedInput[3].Content) }) // Test Case 2: WithAgentInputSchema t.Run("WithAgentInputSchema", func(t *testing.T) { ctx := context.Background() // 1. Define a custom schema customSchema := schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "custom_arg": { Desc: "a custom argument", Required: true, Type: schema.String, }, }) // 2. Set up a mock agent to capture input mockAgent := newMockAgentWithInputCapture("schema-agent", "agent with custom schema", []*AgentEvent{ { AgentName: "schema-agent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("schema processed", nil), Role: schema.Assistant, }, }, }, }) // 3. Create agentTool with the custom schema option agentTool := NewAgentTool(ctx, mockAgent, WithAgentInputSchema(customSchema)) // 4. Verify the Info() method returns the custom schema info, err := agentTool.Info(ctx) assert.NoError(t, err) assert.Equal(t, customSchema, info.ParamsOneOf) // 5. Run the tool with arguments matching the custom schema _, err = agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"custom_arg":"hello world"}`) assert.NoError(t, err) // 6.
Assert that the agent received the arguments unparsed // With a custom schema, the raw JSON arguments become the agent's only input message assert.Len(t, mockAgent.capturedInput, 1) assert.Equal(t, `{"custom_arg":"hello world"}`, mockAgent.capturedInput[0].Content) }) // Test Case 3: WithAgentInputSchema with complex schema t.Run("WithAgentInputSchema_ComplexSchema", func(t *testing.T) { ctx := context.Background() // 1. Define a complex custom schema with multiple parameters complexSchema := schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "name": { Desc: "user name", Required: true, Type: schema.String, }, "age": { Desc: "user age", Required: false, Type: schema.Integer, }, "active": { Desc: "user status", Required: false, Type: schema.Boolean, }, }) // 2. Set up a mock agent mockAgent := newMockAgentWithInputCapture("complex-agent", "agent with complex schema", []*AgentEvent{ { AgentName: "complex-agent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("complex processed", nil), Role: schema.Assistant, }, }, }, }) // 3. Create agentTool with the complex schema option agentTool := NewAgentTool(ctx, mockAgent, WithAgentInputSchema(complexSchema)) // 4. Verify the Info() method returns the complex schema info, err := agentTool.Info(ctx) assert.NoError(t, err) assert.Equal(t, complexSchema, info.ParamsOneOf) // 5. Run the tool with complex arguments _, err = agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"name":"John","age":30,"active":true}`) assert.NoError(t, err) // 6. Assert that the agent received the complex JSON assert.Len(t, mockAgent.capturedInput, 1) assert.Equal(t, `{"name":"John","age":30,"active":true}`, mockAgent.capturedInput[0].Content) }) // Test Case 4: Both options together t.Run("BothOptionsTogether", func(t *testing.T) { ctx := context.Background() // 1.
Define a custom schema customSchema := schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "query": { Desc: "search query", Required: true, Type: schema.String, }, }) // 2. Set up a mock agent mockAgent := newMockAgentWithInputCapture("combined-agent", "agent with both options", []*AgentEvent{ { AgentName: "combined-agent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("combined processed", nil), Role: schema.Assistant, }, }, }, }) // 3. Create agentTool with both options agentTool := NewAgentTool(ctx, mockAgent, WithAgentInputSchema(customSchema), WithFullChatHistoryAsInput()) // 4. Set up a context with chat history using a graph history := []Message{ schema.UserMessage("previous conversation"), schema.AssistantMessage("previous response", nil), } g := compose.NewGraph[string, string](compose.WithGenLocalState(func(ctx context.Context) (state *State) { return &State{ Messages: append(history, schema.AssistantMessage("tool call", nil)), } })) assert.NoError(t, g.AddLambdaNode("1", compose.InvokableLambda(func(ctx context.Context, input string) (output string, err error) { // Run the tool within the graph context that has the state _, err = agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"query":"current query"}`) return "done", err }))) assert.NoError(t, g.AddEdge(compose.START, "1")) assert.NoError(t, g.AddEdge("1", compose.END)) ctx, _ = initRunCtx(ctx, "react-agent", nil) runner, err := g.Compile(ctx) assert.NoError(t, err) // 5. Run the graph which will execute the tool with the state _, err = runner.Invoke(ctx, "") assert.NoError(t, err) // 6.
Verify both options work together info, err := agentTool.Info(ctx) assert.NoError(t, err) assert.Equal(t, customSchema, info.ParamsOneOf) // Full chat history takes precedence: the tool arguments (the custom query) are ignored and the agent receives only the history assert.Len(t, mockAgent.capturedInput, 4) // 2 history + 2 transfer messages assert.Equal(t, "previous conversation", mockAgent.capturedInput[0].Content) assert.Equal(t, "For context: [react-agent] said: previous response.", mockAgent.capturedInput[1].Content) assert.Equal(t, "For context: [react-agent] called tool: `transfer_to_agent` with arguments: combined-agent.", mockAgent.capturedInput[2].Content) assert.Equal(t, "For context: [react-agent] `transfer_to_agent` tool returned result: successfully transferred to agent [combined-agent].", mockAgent.capturedInput[3].Content) }) } // fakeTCM is a tool-calling chat model whose Generate always returns one tool call to the first tool in the options, with arguments {"request":"hello"}. type fakeTCM struct{} func (f *fakeTCM) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { o := model.GetCommonOptions(&model.Options{}, opts...) tc := schema.ToolCall{ID: "id-1", Type: "function"} if len(o.Tools) > 0 { tc.Function.Name = o.Tools[0].Name } tc.Function.Arguments = `{"request":"hello"}` return schema.AssistantMessage("", []schema.ToolCall{tc}), nil } func (f *fakeTCM) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { msg, _ := f.Generate(ctx, input, opts...)
return schema.StreamReaderFromArray([]*schema.Message{msg}), nil } func (f *fakeTCM) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { return f, nil } // emitOnceModel always replies with the plain assistant message "inner2" and makes no tool calls. type emitOnceModel struct{} func (e *emitOnceModel) Generate(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { return schema.AssistantMessage("inner2", nil), nil } func (e *emitOnceModel) Stream(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { m, _ := e.Generate(ctx, input) return schema.StreamReaderFromArray([]*schema.Message{m}), nil } func (e *emitOnceModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { return e, nil } // emitEventsAgent replays a fixed list of events from a goroutine, then closes the stream. type emitEventsAgent struct{ events []*AgentEvent } func (e *emitEventsAgent) Name(context.Context) string { return "emit" } func (e *emitEventsAgent) Description(context.Context) string { return "test" } func (e *emitEventsAgent) Run(context.Context, *AgentInput, ...AgentRunOption) *AsyncIterator[*AgentEvent] { it, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { for _, ev := range e.events { gen.Send(ev) } gen.Close() }() return it } // spyAgent delegates to a wrapped agent and records the runSession found in ctx on the most recent Run; access is guarded by mu. type spyAgent struct { a Agent mu sync.Mutex captured *runSession } func (s *spyAgent) Name(ctx context.Context) string { return s.a.Name(ctx) } func (s *spyAgent) Description(ctx context.Context) string { return s.a.Description(ctx) } func (s *spyAgent) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { if rc := getRunCtx(ctx); rc != nil { s.mu.Lock() s.captured = rc.Session s.mu.Unlock() } return s.a.Run(ctx, input, options...)
} // getCaptured returns the session recorded by the most recent Run. func (s *spyAgent) getCaptured() *runSession { s.mu.Lock() defer s.mu.Unlock() return s.captured } // TestNestedAgentTool_RunPath checks that events forwarded through two levels of agent tools carry the full outer/inner/inner2 run path, and that each agent's session only records its own events. func TestNestedAgentTool_RunPath(t *testing.T) { ctx := context.Background() inner2, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "inner2", Description: "leaf", Model: &emitOnceModel{}, ToolsConfig: ToolsConfig{EmitInternalEvents: true}, }) inner2Spy := &spyAgent{a: inner2} inner2Tool := NewAgentTool(ctx, inner2Spy) inner, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "inner", Description: "mid", Model: &fakeTCM{}, ToolsConfig: ToolsConfig{EmitInternalEvents: true, ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{inner2Tool}}}, }) innerSpy := &spyAgent{a: inner} innerTool := NewAgentTool(ctx, innerSpy) outer, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "outer", Description: "top", Model: &fakeTCM{}, ToolsConfig: ToolsConfig{EmitInternalEvents: true, ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{innerTool}}}, }) input := &AgentInput{Messages: []Message{schema.UserMessage("q")}} ctx, outerRunCtx := initRunCtx(ctx, "outer", input) r := NewRunner(ctx, RunnerConfig{Agent: outer, EnableStreaming: false, CheckPointStore: newBridgeStore()}) it := r.Run(ctx, []Message{schema.UserMessage("q")}) var target *AgentEvent for { ev, ok := it.Next() if !ok { break } if ev.Output != nil && ev.Output.MessageOutput != nil && !ev.Output.MessageOutput.IsStreaming { if ev.Output.MessageOutput.Message != nil && ev.Output.MessageOutput.Message.Content == "inner2" { target = ev break } } } if target == nil { t.Fatalf("no inner2 event found in ephemerals") } got := make([]string, len(target.RunPath)) for i := range target.RunPath { got[i] = target.RunPath[i].agentName } want := []string{"outer", "inner", "inner2"} if len(got) != len(want) { t.Fatalf("unexpected runPath len: got %d want %d: %+v", len(got), len(want), got) } for i := range want { if got[i] != want[i] { t.Fatalf("runPath mismatch at %d: got %s want %s; full:
%+v", i, got[i], want[i], got) } } for _, w := range outerRunCtx.Session.getEvents() { if w.AgentName != "outer" { t.Fatalf("outer session contains non-outer event: %s", w.AgentName) } } if innerSpy.getCaptured() == nil { t.Fatalf("inner spy did not capture session") } for _, w := range innerSpy.getCaptured().getEvents() { if w.AgentName != "inner" { t.Fatalf("inner session contains non-inner event: %s", w.AgentName) } } if inner2Spy.getCaptured() == nil { t.Fatalf("inner2 spy did not capture session") } for _, w := range inner2Spy.getCaptured().getEvents() { if w.AgentName != "inner2" { t.Fatalf("inner2 session contains non-inner2 event: %s", w.AgentName) } } } // TestNestedAgentTool_NoInternalEventsWhenDisabled checks that no inner2 events reach the Runner when EmitInternalEvents is false at every level. func TestNestedAgentTool_NoInternalEventsWhenDisabled(t *testing.T) { ctx := context.Background() inner2, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "inner2", Description: "leaf", Model: &emitOnceModel{}, ToolsConfig: ToolsConfig{EmitInternalEvents: false}, }) inner2Tool := NewAgentTool(ctx, inner2) inner, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "inner", Description: "mid", Model: &fakeTCM{}, ToolsConfig: ToolsConfig{EmitInternalEvents: false, ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{inner2Tool}}}, }) innerTool := NewAgentTool(ctx, inner) outer, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "outer", Description: "top", Model: &fakeTCM{}, ToolsConfig: ToolsConfig{EmitInternalEvents: false, ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{innerTool}}}, }) r := NewRunner(ctx, RunnerConfig{Agent: outer, EnableStreaming: false, CheckPointStore: newBridgeStore()}) it := r.Run(ctx, []Message{schema.UserMessage("q")}) for { ev, ok := it.Next() if !ok { break } if ev.AgentName == "inner2" { t.Fatalf("inner2 internal event should not be emitted when disabled") } } } // TestNestedAgentTool_InnerToolResultNotEmittedToOuter checks that a tool result produced inside the inner agent is never reported as an event of the outer agent. func TestNestedAgentTool_InnerToolResultNotEmittedToOuter(t *testing.T) { ctx := context.Background() innerTool := &simpleTool{name: "inner_tool", result: "inner_tool_result"}
inner, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "inner", Description: "inner agent with tool", Model: &fakeTCM{}, ToolsConfig: ToolsConfig{ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{innerTool}}}, }) innerAgentTool := NewAgentTool(ctx, inner) outer, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "outer", Description: "outer agent", Model: &fakeTCM{}, ToolsConfig: ToolsConfig{ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{innerAgentTool}}}, }) r := NewRunner(ctx, RunnerConfig{Agent: outer, EnableStreaming: false, CheckPointStore: newBridgeStore()}) it := r.Run(ctx, []Message{schema.UserMessage("q")}) var allEvents []*AgentEvent for { ev, ok := it.Next() if !ok { break } allEvents = append(allEvents, ev) } for _, ev := range allEvents { if ev.Output != nil && ev.Output.MessageOutput != nil && ev.Output.MessageOutput.Message != nil && ev.Output.MessageOutput.Message.Role == schema.Tool && ev.AgentName == "outer" && ev.Output.MessageOutput.Message.Content == "inner_tool_result" { t.Fatalf("inner agent's tool result (inner_tool_result) should not be emitted as outer agent's event, but got event with AgentName=%s, Content=%s", ev.AgentName, ev.Output.MessageOutput.Message.Content) } } } // simpleTool is an invokable tool that ignores its arguments and returns a fixed result. type simpleTool struct { name string result string } func (s *simpleTool) Info(context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: s.name, Desc: "simple tool"}, nil } func (s *simpleTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { return s.result, nil } // TestAgentTool_InterruptWithoutCheckpoint checks that an interrupted last event with no stored checkpoint yields an error. NOTE(review): it exercises the local compositeInterruptFromLast copy, not agentTool.InvokableRun itself. 
func TestAgentTool_InterruptWithoutCheckpoint(t *testing.T) { ctx := context.Background() ctx, _ = initRunCtx(ctx, "TestAgent", &AgentInput{Messages: []Message{}}) interrupted := &AgentEvent{AgentName: "TestAgent"} interrupted.Action = StatefulInterrupt(ctx, "info", "state").Action err := compositeInterruptFromLast(ctx, &bridgeStore{}, interrupted) if err == nil { t.Fatalf("expected error for interrupt
without checkpoint") } if !strings.Contains(err.Error(), "interrupt occurred but checkpoint data is missing") { t.Fatalf("unexpected error: %v", err) } } // compositeInterruptFromLast mirrors the interrupt-handling tail of agentTool.InvokableRun for testing. NOTE(review): its "checkpoint data is missing" error text differs from InvokableRun's "cannot find interrupt info", so the two can drift; consider sharing one helper. func compositeInterruptFromLast(ctx context.Context, ms *bridgeStore, lastEvent *AgentEvent) error { if lastEvent == nil || lastEvent.Action == nil || lastEvent.Action.Interrupted == nil { return nil } data, existed, err := ms.Get(ctx, bridgeCheckpointID) if err != nil { return fmt.Errorf("failed to get interrupt info: %w", err) } if !existed { return fmt.Errorf("interrupt occurred but checkpoint data is missing") } return tool.CompositeInterrupt(ctx, "agent tool interrupt", data, lastEvent.Action.internalInterrupted) } // TestAgentTool_InvokableRun_FinalOnly checks that, when no parent event generator is supplied, InvokableRun still returns the inner agent's final message content. func TestAgentTool_InvokableRun_FinalOnly(t *testing.T) { ctx := context.Background() inner2, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "inner2", Description: "leaf", Model: &emitOnceModel{}, ToolsConfig: ToolsConfig{EmitInternalEvents: true}, }) invTool := NewAgentTool(ctx, inner2) out, err := invTool.(tool.InvokableTool).InvokableRun(ctx, `{"request":"q"}`) if err != nil { t.Fatalf("invokable run error: %v", err) } if out != "inner2" { t.Fatalf("unexpected output: %s", out) } } // streamingAgent emits two streaming events ("1","2" then "a","b"); the tool result is expected to come from the last one only. 
type streamingAgent struct{} func (s *streamingAgent) Name(context.Context) string { return "stream" } func (s *streamingAgent) Description(context.Context) string { return "test" } func (s *streamingAgent) Run(context.Context, *AgentInput, ...AgentRunOption) *AsyncIterator[*AgentEvent] { it, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { mv := &MessageVariant{IsStreaming: true, MessageStream: schema.StreamReaderFromArray([]Message{schema.AssistantMessage("1", nil), schema.AssistantMessage("2", nil)})} gen.Send(&AgentEvent{AgentName: "stream", Output: &AgentOutput{MessageOutput: mv}}) mv = &MessageVariant{IsStreaming: true, MessageStream: schema.StreamReaderFromArray([]Message{schema.AssistantMessage("a", nil), schema.AssistantMessage("b", nil)})} gen.Send(&AgentEvent{AgentName: "stream", Output:
&AgentOutput{MessageOutput: mv}}) gen.Close() }() return it } // TestAgentTool_InvokableRun_StreamingVariant checks that only the last streaming event is used as the tool result. func TestAgentTool_InvokableRun_StreamingVariant(t *testing.T) { ctx := context.Background() agent := &streamingAgent{} it := NewAgentTool(ctx, agent) out, err := it.(tool.InvokableTool).InvokableRun(ctx, `{"request":"q"}`) if err != nil { t.Fatalf("invokable run error: %v", err) } if out != "ab" { t.Fatalf("unexpected output: %s", out) } } // TestSequentialWorkflow_WithChatModelAgentTool_NestedRunPathAndSessions repeats the nested run-path and session checks with a sequential workflow as the outermost agent. func TestSequentialWorkflow_WithChatModelAgentTool_NestedRunPathAndSessions(t *testing.T) { ctx := context.Background() inner2, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "inner2", Description: "leaf", Model: &emitOnceModel{}, ToolsConfig: ToolsConfig{EmitInternalEvents: true}, }) inner2Spy := &spyAgent{a: inner2} inner2ToolSpy := NewAgentTool(ctx, inner2Spy) innerWithSpy, _ := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "inner", Description: "mid", Model: &fakeTCM{}, ToolsConfig: ToolsConfig{EmitInternalEvents: true, ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{inner2ToolSpy}}}, }) innerSpy := &spyAgent{a: innerWithSpy} outer, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ Name: "outer-seq", Description: "workflow", SubAgents: []Agent{innerSpy}, }) if err != nil { t.Fatalf("new sequential agent err: %v", err) } input := &AgentInput{Messages: []Message{schema.UserMessage("q")}} ctx, outerRunCtx := initRunCtx(ctx, "outer-seq", input) r := NewRunner(ctx, RunnerConfig{Agent: outer, EnableStreaming: false, CheckPointStore: newBridgeStore()}) it := r.Run(ctx, []Message{schema.UserMessage("q")}) var target *AgentEvent for { ev, ok := it.Next() if !ok { break } if ev.Output != nil && ev.Output.MessageOutput != nil && !ev.Output.MessageOutput.IsStreaming { if ev.Output.MessageOutput.Message != nil && ev.Output.MessageOutput.Message.Content == "inner2" { target = ev break } } } if target == nil { t.Fatalf("no inner2 event found") } got := make([]string, len(target.RunPath)) for i := range target.RunPath { got[i] =
target.RunPath[i].agentName } want := []string{"outer-seq", "inner", "inner2"} if len(got) != len(want) { t.Fatalf("unexpected runPath len: got %d want %d: %+v", len(got), len(want), got) } for i := range want { if got[i] != want[i] { t.Fatalf("runPath mismatch at %d: got %s want %s; full: %+v", i, got[i], want[i], got) } } for _, w := range outerRunCtx.Session.getEvents() { if w.AgentName != "outer-seq" { t.Fatalf("outer session contains non-outer event: %s", w.AgentName) } } if innerSpy.getCaptured() == nil { t.Fatalf("inner spy did not capture session") } for _, w := range innerSpy.getCaptured().getEvents() { if w.AgentName != "inner" { t.Fatalf("inner session contains non-inner event: %s", w.AgentName) } } if inner2Spy.getCaptured() == nil { t.Fatalf("inner2 spy did not capture session") } for _, w := range inner2Spy.getCaptured().getEvents() { if w.AgentName != "inner2" { t.Fatalf("inner2 session contains non-inner2 event: %s", w.AgentName) } } } // TestRunPathGating_IgnoresInnerExitAndAllowsOutput checks that a flow agent does not stop on an Exit action whose RunPath belongs to an inner agent, so a later output event is still delivered. func TestRunPathGating_IgnoresInnerExitAndAllowsOutput(t *testing.T) { ctx := context.Background() innerExit := &AgentEvent{Action: &AgentAction{Exit: true}, RunPath: []RunStep{{agentName: "inner"}}} finalOut := EventFromMessage(schema.AssistantMessage("ok", nil), nil, schema.Assistant, "") sub := &emitEventsAgent{events: []*AgentEvent{innerExit, finalOut}} fa := toFlowAgent(ctx, sub) it := fa.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("q")}}) var sawFinal bool for { ev, ok := it.Next() if !ok { break } if ev.Output != nil && ev.Output.MessageOutput != nil && !ev.Output.MessageOutput.IsStreaming { if ev.Output.MessageOutput.Message != nil && ev.Output.MessageOutput.Message.Content == "ok" { sawFinal = true } } } if !sawFinal { t.Fatalf("final output not observed; parent may have exited on inner Exit action") } } func TestRunPathGating_IgnoresInnerTransfer(t *testing.T) { ctx := context.Background() innerTransfer := &AgentEvent{Action: NewTransferToAgentAction("ghost"), RunPath: []RunStep{{agentName:
"inner"}}} finalOut := EventFromMessage(schema.AssistantMessage("done", nil), nil, schema.Assistant, "") sub := &emitEventsAgent{events: []*AgentEvent{innerTransfer, finalOut}} fa := toFlowAgent(ctx, sub) it := fa.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("q")}}) var outputs int for { ev, ok := it.Next() if !ok { break } if ev.Output != nil && ev.Output.MessageOutput != nil && !ev.Output.MessageOutput.IsStreaming { if ev.Output.MessageOutput.Message != nil { outputs++ } } } if outputs == 0 { t.Fatalf("no outputs observed; parent may have transferred on inner transfer action") } } type streamAgent struct{} func (s *streamAgent) Name(context.Context) string { return "s" } func (s *streamAgent) Description(context.Context) string { return "s" } func (s *streamAgent) Run(context.Context, *AgentInput, ...AgentRunOption) *AsyncIterator[*AgentEvent] { it, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { frames := []*schema.Message{ schema.AssistantMessage("hello ", nil), schema.AssistantMessage("world", nil), } stream := schema.StreamReaderFromArray(frames) gen.Send(EventFromMessage(nil, stream, schema.Assistant, "")) gen.Close() }() return it } func TestInvokableAgentTool_InfoAndRun(t *testing.T) { ctx := context.Background() at := NewAgentTool(ctx, &streamAgent{}) info, err := at.Info(ctx) assert.NoError(t, err) assert.Equal(t, "s", info.Name) assert.Equal(t, "s", info.Desc) js, err := info.ParamsOneOf.ToJSONSchema() assert.NoError(t, err) found := false for _, r := range js.Required { if r == "request" { found = true break } } assert.True(t, found) prop, ok := js.Properties.Get("request") assert.True(t, ok) assert.Equal(t, string(schema.String), prop.Type) custom := schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "x": {Desc: "arg", Required: true, Type: schema.String}, }) at2 := NewAgentTool(ctx, &streamAgent{}, WithAgentInputSchema(custom)) info2, err := at2.Info(ctx) assert.NoError(t, err) assert.Equal(t, custom, 
info2.ParamsOneOf) out, err := at.(tool.InvokableTool).InvokableRun(ctx, `{"request":"x"}`) assert.NoError(t, err) assert.Equal(t, "hello world", out) } type emptyAgent struct{} func (e *emptyAgent) Name(context.Context) string { return "empty" } func (e *emptyAgent) Description(context.Context) string { return "empty" } func (e *emptyAgent) Run(context.Context, *AgentInput, ...AgentRunOption) *AsyncIterator[*AgentEvent] { it, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { gen.Close() }() return it } type noOutputAgent struct{} func (n *noOutputAgent) Name(context.Context) string { return "no" } func (n *noOutputAgent) Description(context.Context) string { return "no" } func (n *noOutputAgent) Run(context.Context, *AgentInput, ...AgentRunOption) *AsyncIterator[*AgentEvent] { it, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { gen.Send(&AgentEvent{}); gen.Close() }() return it } func TestInvokableAgentTool_ErrorCases(t *testing.T) { ctx := context.Background() atEmpty := NewAgentTool(ctx, &emptyAgent{}) out, err := atEmpty.(tool.InvokableTool).InvokableRun(ctx, `{"request":"x"}`) assert.Equal(t, "", out) assert.Error(t, err) atNo := NewAgentTool(ctx, &noOutputAgent{}) out2, err := atNo.(tool.InvokableTool).InvokableRun(ctx, `{"request":"x"}`) assert.NoError(t, err) assert.Equal(t, "", out2) } ================================================ FILE: adk/call_option.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package adk import "github.com/cloudwego/eino/callbacks" type options struct { sharedParentSession bool sessionValues map[string]any checkPointID *string skipTransferMessages bool handlers []callbacks.Handler } // AgentRunOption is the call option for adk Agent. type AgentRunOption struct { implSpecificOptFn any // specify which Agent can see this AgentRunOption, if empty, all Agents can see this AgentRunOption agentNames []string } func (o AgentRunOption) DesignateAgent(name ...string) AgentRunOption { o.agentNames = append(o.agentNames, name...) return o } func getCommonOptions(base *options, opts ...AgentRunOption) *options { if base == nil { base = &options{} } return GetImplSpecificOptions(base, opts...) } // WithSessionValues sets session-scoped values for the agent run. func WithSessionValues(v map[string]any) AgentRunOption { return WrapImplSpecificOptFn(func(o *options) { o.sessionValues = v }) } // WithSkipTransferMessages disables forwarding transfer messages during execution. func WithSkipTransferMessages() AgentRunOption { return WrapImplSpecificOptFn(func(t *options) { t.skipTransferMessages = true }) } func withSharedParentSession() AgentRunOption { return WrapImplSpecificOptFn(func(o *options) { o.sharedParentSession = true }) } // WithCallbacks adds callback handlers to receive agent lifecycle events. // Handlers receive OnStart with AgentCallbackInput and OnEnd with AgentCallbackOutput. // Multiple handlers can be added; each receives an independent copy of the event stream. func WithCallbacks(handlers ...callbacks.Handler) AgentRunOption { return WrapImplSpecificOptFn(func(o *options) { o.handlers = append(o.handlers, handlers...) }) } // WrapImplSpecificOptFn is the option to wrap the implementation specific option function. 
//
// The returned option is applied by GetImplSpecificOptions[T] only when T is the
// same type parameter used here; other option types ignore it.
func WrapImplSpecificOptFn[T any](optFn func(*T)) AgentRunOption {
	return AgentRunOption{implSpecificOptFn: optFn}
}

// GetImplSpecificOptions extracts the implementation-specific options from an
// AgentRunOption list and applies them, in order, on top of base. When base is
// nil, a zero-valued T is allocated.
//
// An option is skipped when it wraps no function or a function for a different
// type. Agent designation (see AgentRunOption.DesignateAgent) is not consulted.
//
// e.g.
//
//	myOption := &MyOption{
//		Field1: "default_value",
//	}
//
//	myOption = adk.GetImplSpecificOptions(myOption, opts...)
func GetImplSpecificOptions[T any](base *T, opts ...AgentRunOption) *T {
	if base == nil {
		base = new(T)
	}
	for _, opt := range opts {
		apply, ok := opt.implSpecificOptFn.(func(*T))
		if !ok {
			continue
		}
		apply(base)
	}
	return base
}

// filterCallbackHandlersForNestedAgents removes callback handlers that have already been applied
// to the current agent before passing opts to nested inner agents.
//
// This is necessary for workflow agents (LoopAgent, SequentialAgent, ParallelAgent) because:
//  1. Callback handlers designated for the current agent are applied via initAgentCallbacks(),
//     which stores them in the context.
//  2. Nested inner agents inherit this context, so they automatically receive these callbacks.
//  3. If we also pass these handlers in opts to inner agents, they would be applied twice,
//     causing duplicate callback invocations.
//
// Note: This only applies to workflow agents where inner agents inherit context from the parent.
// For flowAgent's sub-agents (which are peer agents that transfer to each other), the full opts
// are passed since they don't inherit the parent's callback context.
func filterCallbackHandlersForNestedAgents(currentAgentName string, opts []AgentRunOption) []AgentRunOption { if len(opts) == 0 { return nil } var filteredOpts []AgentRunOption for i := range opts { opt := opts[i] if opt.implSpecificOptFn == nil { filteredOpts = append(filteredOpts, opt) continue } if _, isCallbackOpt := opt.implSpecificOptFn.(func(*options)); isCallbackOpt { testOpt := &options{} opt.implSpecificOptFn.(func(*options))(testOpt) if len(testOpt.handlers) > 0 { if len(opt.agentNames) == 0 { continue } matched := false for _, name := range opt.agentNames { if name == currentAgentName { matched = true break } } if matched { continue } } } filteredOpts = append(filteredOpts, opt) } return filteredOpts } func filterOptions(agentName string, opts []AgentRunOption) []AgentRunOption { if len(opts) == 0 { return nil } var filteredOpts []AgentRunOption for i := range opts { opt := opts[i] if len(opt.agentNames) == 0 { filteredOpts = append(filteredOpts, opt) continue } for j := range opt.agentNames { if opt.agentNames[j] == agentName { filteredOpts = append(filteredOpts, opt) break } } } return filteredOpts } ================================================ FILE: adk/call_option_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adk import ( "context" ) type mockAgentForOption struct { opts []AgentRunOption options *options } func (m *mockAgentForOption) Name(ctx context.Context) string { return "agent_1" } func (m *mockAgentForOption) Description(ctx context.Context) string { return "" } func (m *mockAgentForOption) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { m.opts = opts m.options = getCommonOptions(&options{}, opts...) return nil } ================================================ FILE: adk/callback.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" icb "github.com/cloudwego/eino/internal/callbacks" ) // AgentCallbackInput represents the input passed to agent callbacks during OnStart. // Use ConvAgentCallbackInput to safely convert from callbacks.CallbackInput. type AgentCallbackInput struct { // Input contains the agent input for a new run. Nil when resuming. Input *AgentInput // ResumeInfo contains resume information when resuming from an interrupt. Nil for new runs. ResumeInfo *ResumeInfo } // AgentCallbackOutput represents the output passed to agent callbacks during OnEnd. // Use ConvAgentCallbackOutput to safely convert from callbacks.CallbackOutput. 
// // Important: The Events iterator should be consumed asynchronously to avoid blocking // the agent execution. Each callback handler receives an independent copy of the iterator. type AgentCallbackOutput struct { // Events provides a stream of agent events. Each handler receives its own copy. Events *AsyncIterator[*AgentEvent] } func copyEventIterator(iter *AsyncIterator[*AgentEvent], n int) []*AsyncIterator[*AgentEvent] { if n <= 0 { return nil } if n == 1 { return []*AsyncIterator[*AgentEvent]{iter} } iterators := make([]*AsyncIterator[*AgentEvent], n) generators := make([]*AsyncGenerator[*AgentEvent], n) for i := 0; i < n; i++ { iterators[i], generators[i] = NewAsyncIteratorPair[*AgentEvent]() } go func() { defer func() { for _, g := range generators { g.Close() } }() for { event, ok := iter.Next() if !ok { break } for i := 0; i < n-1; i++ { generators[i].Send(copyAgentEvent(event)) } generators[n-1].Send(event) } }() return iterators } func copyAgentCallbackOutput(out *AgentCallbackOutput, n int) []*AgentCallbackOutput { if out == nil || out.Events == nil { result := make([]*AgentCallbackOutput, n) for i := 0; i < n; i++ { result[i] = out } return result } iters := copyEventIterator(out.Events, n) result := make([]*AgentCallbackOutput, n) for i, iter := range iters { result[i] = &AgentCallbackOutput{Events: iter} } return result } // ConvAgentCallbackInput converts a generic CallbackInput to AgentCallbackInput. // Returns nil if the input is not an AgentCallbackInput. func ConvAgentCallbackInput(input callbacks.CallbackInput) *AgentCallbackInput { if v, ok := input.(*AgentCallbackInput); ok { return v } return nil } // ConvAgentCallbackOutput converts a generic CallbackOutput to AgentCallbackOutput. // Returns nil if the output is not an AgentCallbackOutput. 
func ConvAgentCallbackOutput(output callbacks.CallbackOutput) *AgentCallbackOutput { if v, ok := output.(*AgentCallbackOutput); ok { return v } return nil } func initAgentCallbacks(ctx context.Context, agentName, agentType string, opts ...AgentRunOption) context.Context { ri := &callbacks.RunInfo{ Name: agentName, Type: agentType, Component: ComponentOfAgent, } o := getCommonOptions(nil, opts...) if len(o.handlers) == 0 { return icb.ReuseHandlers(ctx, ri) } return icb.AppendHandlers(ctx, ri, o.handlers...) } func getAgentType(agent Agent) string { if typer, ok := agent.(components.Typer); ok { return typer.GetType() } return "" } ================================================ FILE: adk/callback_integration_test.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adk import ( "context" "sync" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/callbacks" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) type callbackRecorder struct { mu sync.Mutex onStartCalled bool onEndCalled bool runInfo *callbacks.RunInfo inputReceived *AgentCallbackInput eventsReceived []*AgentEvent eventsDone chan struct{} closeOnce sync.Once } func (r *callbackRecorder) getOnStartCalled() bool { r.mu.Lock() defer r.mu.Unlock() return r.onStartCalled } func (r *callbackRecorder) getOnEndCalled() bool { r.mu.Lock() defer r.mu.Unlock() return r.onEndCalled } func (r *callbackRecorder) getEventsReceived() []*AgentEvent { r.mu.Lock() defer r.mu.Unlock() result := make([]*AgentEvent, len(r.eventsReceived)) copy(result, r.eventsReceived) return result } func newRecordingHandler(recorder *callbackRecorder) callbacks.Handler { recorder.eventsDone = make(chan struct{}) return callbacks.NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Component != ComponentOfAgent { return ctx } recorder.mu.Lock() defer recorder.mu.Unlock() recorder.onStartCalled = true recorder.runInfo = info if agentInput := ConvAgentCallbackInput(input); agentInput != nil { recorder.inputReceived = agentInput } return ctx }). 
OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { if info.Component != ComponentOfAgent { return ctx } recorder.mu.Lock() recorder.onEndCalled = true recorder.runInfo = info recorder.mu.Unlock() if agentOutput := ConvAgentCallbackOutput(output); agentOutput != nil { if agentOutput.Events != nil { go func() { defer recorder.closeOnce.Do(func() { close(recorder.eventsDone) }) for { event, ok := agentOutput.Events.Next() if !ok { break } recorder.mu.Lock() recorder.eventsReceived = append(recorder.eventsReceived, event) recorder.mu.Unlock() } }() return ctx } } recorder.closeOnce.Do(func() { close(recorder.eventsDone) }) return ctx }). Build() } func TestCallbackOnStartInvocation(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("test response", nil), nil). 
Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for callback", Instruction: "You are a test agent", Model: cm, }) assert.NoError(t, err) recorder := &callbackRecorder{} handler := newRecordingHandler(recorder) runner := NewRunner(ctx, RunnerConfig{Agent: agent}) iter := runner.Query(ctx, "hello", WithCallbacks(handler)) for { _, ok := iter.Next() if !ok { break } } <-recorder.eventsDone assert.True(t, recorder.onStartCalled, "OnStart should be called") assert.NotNil(t, recorder.inputReceived, "Input should be received") assert.NotNil(t, recorder.inputReceived.Input, "AgentInput should be set") assert.Len(t, recorder.inputReceived.Input.Messages, 1) } func TestCallbackOnEndInvocation(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("test response", nil), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for callback", Instruction: "You are a test agent", Model: cm, }) assert.NoError(t, err) recorder := &callbackRecorder{} handler := newRecordingHandler(recorder) runner := NewRunner(ctx, RunnerConfig{Agent: agent}) iter := runner.Query(ctx, "hello", WithCallbacks(handler)) for { _, ok := iter.Next() if !ok { break } } <-recorder.eventsDone assert.True(t, recorder.onEndCalled, "OnEnd should be called") assert.NotEmpty(t, recorder.eventsReceived, "Events should be received") } func TestCallbackRunInfoForChatModelAgent(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.AssistantMessage("test response", nil), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestChatAgent", Description: "Test chat agent", Instruction: "You are a test agent", Model: cm, }) assert.NoError(t, err) recorder := &callbackRecorder{} handler := newRecordingHandler(recorder) runner := NewRunner(ctx, RunnerConfig{Agent: agent}) iter := runner.Query(ctx, "hello", WithCallbacks(handler)) for { _, ok := iter.Next() if !ok { break } } <-recorder.eventsDone assert.NotNil(t, recorder.runInfo) assert.Equal(t, "TestChatAgent", recorder.runInfo.Name) assert.Equal(t, "ChatModel", recorder.runInfo.Type) assert.Equal(t, ComponentOfAgent, recorder.runInfo.Component) } func TestMultipleCallbackHandlers(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("test response", nil), nil). 
Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Instruction: "You are a test agent", Model: cm, }) assert.NoError(t, err) recorder1 := &callbackRecorder{} recorder2 := &callbackRecorder{} handler1 := newRecordingHandler(recorder1) handler2 := newRecordingHandler(recorder2) runner := NewRunner(ctx, RunnerConfig{Agent: agent}) iter := runner.Query(ctx, "hello", WithCallbacks(handler1, handler2)) for { _, ok := iter.Next() if !ok { break } } <-recorder1.eventsDone <-recorder2.eventsDone assert.True(t, recorder1.onStartCalled, "Handler1 OnStart should be called") assert.True(t, recorder2.onStartCalled, "Handler2 OnStart should be called") assert.True(t, recorder1.onEndCalled, "Handler1 OnEnd should be called") assert.True(t, recorder2.onEndCalled, "Handler2 OnEnd should be called") assert.NotEmpty(t, recorder1.eventsReceived, "Handler1 should receive events") assert.NotEmpty(t, recorder2.eventsReceived, "Handler2 should receive events") } func TestCallbackWithWorkflowAgent(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm1 := mockModel.NewMockToolCallingChatModel(ctrl) cm1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("response 1", nil), nil). Times(1) cm1.EXPECT().WithTools(gomock.Any()).Return(cm1, nil).AnyTimes() cm2 := mockModel.NewMockToolCallingChatModel(ctrl) cm2.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("response 2", nil), nil). 
Times(1) cm2.EXPECT().WithTools(gomock.Any()).Return(cm2, nil).AnyTimes() agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "Agent1", Description: "First agent", Instruction: "You are agent 1", Model: cm1, }) assert.NoError(t, err) agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "Agent2", Description: "Second agent", Instruction: "You are agent 2", Model: cm2, }) assert.NoError(t, err) seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ Name: "SequentialAgent", Description: "Sequential workflow", SubAgents: []Agent{agent1, agent2}, }) assert.NoError(t, err) var callbackInfos []*callbacks.RunInfo handler := callbacks.NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Component == ComponentOfAgent { callbackInfos = append(callbackInfos, info) } return ctx }). Build() runner := NewRunner(ctx, RunnerConfig{Agent: seqAgent}) iter := runner.Query(ctx, "hello", WithCallbacks(handler)) for { _, ok := iter.Next() if !ok { break } } assert.NotEmpty(t, callbackInfos, "OnStart should be called for agents") foundAgent1 := false foundAgent2 := false for _, info := range callbackInfos { if info.Name == "Agent1" && info.Type == "ChatModel" { foundAgent1 = true } if info.Name == "Agent2" && info.Type == "ChatModel" { foundAgent2 = true } } assert.True(t, foundAgent1, "Agent1 callback should be invoked") assert.True(t, foundAgent2, "Agent2 callback should be invoked") } func TestCallbackEventsMatchAgentOutput(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() expectedContent := "This is the test response content" cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage(expectedContent, nil), nil). 
Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Instruction: "You are a test agent", Model: cm, }) assert.NoError(t, err) recorder := &callbackRecorder{} handler := newRecordingHandler(recorder) var agentEvents []*AgentEvent runner := NewRunner(ctx, RunnerConfig{Agent: agent}) iter := runner.Query(ctx, "hello", WithCallbacks(handler)) for { event, ok := iter.Next() if !ok { break } agentEvents = append(agentEvents, event) } <-recorder.eventsDone assert.NotEmpty(t, agentEvents, "Agent should emit events") assert.NotEmpty(t, recorder.eventsReceived, "Callback should receive events") foundExpectedContent := false for _, event := range recorder.eventsReceived { if event.Output != nil && event.Output.MessageOutput != nil { msg := event.Output.MessageOutput.Message if msg != nil && msg.Content == expectedContent { foundExpectedContent = true break } } } assert.True(t, foundExpectedContent, "Callback events should contain the expected content") } func TestCallbackOnEndForWorkflowAgent(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm1 := mockModel.NewMockToolCallingChatModel(ctrl) cm1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("response 1", nil), nil). Times(1) cm1.EXPECT().WithTools(gomock.Any()).Return(cm1, nil).AnyTimes() cm2 := mockModel.NewMockToolCallingChatModel(ctrl) cm2.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("response 2", nil), nil). 
Times(1) cm2.EXPECT().WithTools(gomock.Any()).Return(cm2, nil).AnyTimes() agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "Agent1", Description: "First agent", Instruction: "You are agent 1", Model: cm1, }) assert.NoError(t, err) agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "Agent2", Description: "Second agent", Instruction: "You are agent 2", Model: cm2, }) assert.NoError(t, err) seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ Name: "SequentialAgent", Description: "Sequential workflow", SubAgents: []Agent{agent1, agent2}, }) assert.NoError(t, err) recorder := &callbackRecorder{} handler := newRecordingHandler(recorder) runner := NewRunner(ctx, RunnerConfig{Agent: seqAgent}) iter := runner.Query(ctx, "hello", WithCallbacks(handler)) for { _, ok := iter.Next() if !ok { break } } <-recorder.eventsDone assert.True(t, recorder.getOnStartCalled(), "OnStart should be called for workflow agent") assert.True(t, recorder.getOnEndCalled(), "OnEnd should be called for workflow agent") assert.NotEmpty(t, recorder.getEventsReceived(), "Events should be received for workflow agent") }

type ctxKeyForTest string

const testOnStartMarkerKey ctxKeyForTest = "onStartMarker"

// newAgent1ToAgent2Runner builds the two-agent fixture shared by the callback
// tests below and returns a Runner over it.
//
// Agent1 is backed by a mock model whose single Generate call returns a
// transfer_to_agent tool call targeting "Agent2"; Agent2 is backed by a mock
// model whose single Generate call returns a plain assistant message. Both
// Generate expectations are Times(1), so ctrl reports a failure if either model
// is not called exactly once. Agent2 is attached to Agent1 via SetSubAgents.
//
// agent1Description is a parameter because it is the only setup detail that
// differs between the tests that use this fixture.
func newAgent1ToAgent2Runner(ctx context.Context, t *testing.T, ctrl *gomock.Controller, agent1Description string) *Runner {
	t.Helper()

	cm1 := mockModel.NewMockToolCallingChatModel(ctrl)
	cm1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(schema.AssistantMessage("transferring to Agent2", []schema.ToolCall{
			{
				ID: "transfer_1",
				Function: schema.FunctionCall{
					Name:      TransferToAgentToolName,
					Arguments: `{"agent_name": "Agent2"}`,
				},
			},
		}), nil).
		Times(1)
	cm1.EXPECT().WithTools(gomock.Any()).Return(cm1, nil).AnyTimes()

	cm2 := mockModel.NewMockToolCallingChatModel(ctrl)
	cm2.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(schema.AssistantMessage("final response from Agent2", nil), nil).
		Times(1)
	cm2.EXPECT().WithTools(gomock.Any()).Return(cm2, nil).AnyTimes()

	agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
		Name:        "Agent1",
		Description: agent1Description,
		Instruction: "You are agent 1",
		Model:       cm1,
	})
	assert.NoError(t, err)

	agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
		Name:        "Agent2",
		Description: "Second agent",
		Instruction: "You are agent 2",
		Model:       cm2,
	})
	assert.NoError(t, err)

	agentWithSubAgents, err := SetSubAgents(ctx, agent1, []Agent{agent2})
	assert.NoError(t, err)

	return NewRunner(ctx, RunnerConfig{Agent: agentWithSubAgents})
}

// drainQueryEvents calls iter.Next until it reports that no more events remain.
func drainQueryEvents(iter *AsyncIterator[*AgentEvent]) {
	for {
		if _, ok := iter.Next(); !ok {
			return
		}
	}
}

// drainEventsOnAgentEnd is the OnEnd callback shared by the handlers below. For
// agent components whose callback output carries an event stream, it consumes
// that stream in a background goroutine; every other call is a no-op. The
// goroutine is not awaited — it exits once the stream reports no more events.
func drainEventsOnAgentEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
	if info.Component != ComponentOfAgent {
		return ctx
	}
	if agentOutput := ConvAgentCallbackOutput(output); agentOutput != nil && agentOutput.Events != nil {
		go drainQueryEvents(agentOutput.Events)
	}
	return ctx
}

// newOnStartCountingHandler returns a handler that, for agent components only,
// increments counts[info.Name] under mu on every OnStart and drains the agent's
// event stream on OnEnd.
func newOnStartCountingHandler(mu *sync.Mutex, counts map[string]int) callbacks.Handler {
	return callbacks.NewHandlerBuilder().
		OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, _ callbacks.CallbackInput) context.Context {
			if info.Component != ComponentOfAgent {
				return ctx
			}
			mu.Lock()
			counts[info.Name]++
			mu.Unlock()
			return ctx
		}).
		OnEndFn(drainEventsOnAgentEnd).
		Build()
}

func TestSubAgentContextIsolation(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	runner := newAgent1ToAgent2Runner(ctx, t, ctrl, "First agent that transfers to Agent2")

	var mu sync.Mutex
	onStartContextMarkers := make(map[string][]string)
	handler := callbacks.NewHandlerBuilder().
		OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, _ callbacks.CallbackInput) context.Context {
			if info.Component != ComponentOfAgent {
				return ctx
			}
			// Record the marker (if any) visible to this agent's OnStart, then tag
			// the returned context with this agent's own marker. If the context
			// returned by Agent1's OnStart leaked into Agent2, Agent2 would record
			// "Agent1_marker" here.
			mu.Lock()
			marker, _ := ctx.Value(testOnStartMarkerKey).(string)
			onStartContextMarkers[info.Name] = append(onStartContextMarkers[info.Name], marker)
			mu.Unlock()
			return context.WithValue(ctx, testOnStartMarkerKey, info.Name+"_marker")
		}).
		OnEndFn(drainEventsOnAgentEnd).
		Build()

	drainQueryEvents(runner.Query(ctx, "hello", WithCallbacks(handler)))

	mu.Lock()
	defer mu.Unlock()

	assert.NotEmpty(t, onStartContextMarkers["Agent1"], "Agent1's OnStart should be called")
	assert.NotEmpty(t, onStartContextMarkers["Agent2"], "Agent2's OnStart should be called")

	if len(onStartContextMarkers["Agent1"]) > 0 {
		assert.Equal(t, "", onStartContextMarkers["Agent1"][0], "Agent1's OnStart should receive context without marker (initial context)")
	}
	if len(onStartContextMarkers["Agent2"]) > 0 {
		assert.Equal(t, "", onStartContextMarkers["Agent2"][0], "Agent2's first OnStart should NOT inherit Agent1's marker - context should be isolated")
	}
}

func TestCallbackDesignatedToSpecificAgent(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	runner := newAgent1ToAgent2Runner(ctx, t, ctrl, "First agent that transfers to Agent2")

	var mu sync.Mutex
	onStartCalls := make(map[string]int)
	agent2OnlyHandler := newOnStartCountingHandler(&mu, onStartCalls)

	drainQueryEvents(runner.Query(ctx, "hello", WithCallbacks(agent2OnlyHandler).DesignateAgent("Agent2")))

	mu.Lock()
	defer mu.Unlock()

	assert.Equal(t, 0, onStartCalls["Agent1"], "Agent1's OnStart should NOT be called when handler is designated to Agent2")
	assert.Equal(t, 1, onStartCalls["Agent2"], "Agent2's OnStart should be called exactly once")
}

func TestCallbackDesignatedToMultipleAgents(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	runner := newAgent1ToAgent2Runner(ctx, t, ctrl, "First agent")

	var mu sync.Mutex
	onStartCalls := make(map[string]int)
	agent1And2Handler := newOnStartCountingHandler(&mu, onStartCalls)

	drainQueryEvents(runner.Query(ctx, "hello", WithCallbacks(agent1And2Handler).DesignateAgent("Agent1", "Agent2")))

	mu.Lock()
	defer mu.Unlock()

	assert.Equal(t, 1, onStartCalls["Agent1"], "Agent1's OnStart should be called exactly once")
	assert.Equal(t, 1, onStartCalls["Agent2"], "Agent2's OnStart should be called exactly once")
}

func TestCallbackDesignatedExcludesNonMatchingAgents(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	runner := newAgent1ToAgent2Runner(ctx, t, ctrl, "First agent")

	var mu sync.Mutex
	onStartCalls := make(map[string]int)
	agent1OnlyHandler := newOnStartCountingHandler(&mu, onStartCalls)

	drainQueryEvents(runner.Query(ctx, "hello", WithCallbacks(agent1OnlyHandler).DesignateAgent("Agent1")))

	mu.Lock()
	defer mu.Unlock()

	assert.Equal(t, 1, onStartCalls["Agent1"], "Agent1's OnStart should be called exactly once")
	assert.Equal(t, 0, onStartCalls["Agent2"], "Agent2's OnStart should NOT be called when handler is designated only to Agent1")
}

func TestMixedDesignatedAndGlobalCallbacks(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	runner := newAgent1ToAgent2Runner(ctx, t, ctrl, "First agent that transfers to Agent2")

	// Both handlers share one mutex, mirroring a single lock over both tallies.
	var mu sync.Mutex
	globalHandlerCalls := make(map[string]int)
	agent2OnlyHandlerCalls := make(map[string]int)
	globalHandler := newOnStartCountingHandler(&mu, globalHandlerCalls)
	agent2OnlyHandler := newOnStartCountingHandler(&mu, agent2OnlyHandlerCalls)

	drainQueryEvents(runner.Query(ctx, "hello",
		WithCallbacks(globalHandler),
		WithCallbacks(agent2OnlyHandler).DesignateAgent("Agent2"),
	))

	mu.Lock()
	defer mu.Unlock()

	assert.Equal(t, 1, globalHandlerCalls["Agent1"], "Global handler should fire for Agent1")
	assert.Equal(t, 1, globalHandlerCalls["Agent2"], "Global handler should fire for Agent2")
	assert.Equal(t, 0, agent2OnlyHandlerCalls["Agent1"], "Agent2-only handler should NOT fire for Agent1")
	assert.Equal(t, 1, agent2OnlyHandlerCalls["Agent2"], "Agent2-only handler should fire for Agent2")
}

func TestOnStartCalledOncePerAgentWithDesignation(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	runner := newAgent1ToAgent2Runner(ctx, t, ctrl, "First agent that transfers to Agent2")

	var mu sync.Mutex
	onStartCalls := make(map[string]int)
	handler := newOnStartCountingHandler(&mu, onStartCalls)

	drainQueryEvents(runner.Query(ctx, "hello", WithCallbacks(handler)))

	mu.Lock()
	defer mu.Unlock()

	assert.Equal(t, 1, onStartCalls["Agent1"], "Agent1's OnStart should be called exactly once")
	assert.Equal(t, 1, onStartCalls["Agent2"], "Agent2's OnStart should be called exactly once")
}

================================================ FILE: adk/callback_test.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) func TestCopyEventIterator(t *testing.T) { t.Run("n=0 returns nil", func(t *testing.T) { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { gen.Send(&AgentEvent{AgentName: "test"}) gen.Close() }() result := copyEventIterator(iter, 0) assert.Nil(t, result) }) t.Run("n=1 returns original iterator", func(t *testing.T) { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { gen.Send(&AgentEvent{AgentName: "test"}) gen.Close() }() result := copyEventIterator(iter, 1) assert.Len(t, result, 1) assert.Equal(t, iter, result[0]) }) t.Run("n>1 creates n independent copies", func(t *testing.T) { iter, gen := NewAsyncIteratorPair[*AgentEvent]() events := []*AgentEvent{ {AgentName: "agent1", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("msg1", nil)}}}, {AgentName: "agent2", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("msg2", nil)}}}, } go func() { for _, e := range events { gen.Send(e) } gen.Close() }() n := 3 copies := copyEventIterator(iter, n) assert.Len(t, copies, n) var wg sync.WaitGroup receivedEvents := make([][]*AgentEvent, n) for i := 0; i < n; i++ { wg.Add(1) go func(idx int) { defer wg.Done() for { event, ok := copies[idx].Next() if !ok { break } receivedEvents[idx] = append(receivedEvents[idx], event) } }(i) } wg.Wait() for i := 0; i < n; i++ { assert.Len(t, receivedEvents[i], len(events), "iterator %d should 
receive all events", i) for j, e := range receivedEvents[i] { assert.Equal(t, events[j].AgentName, e.AgentName) } } }) } func TestCopyAgentCallbackOutput(t *testing.T) { t.Run("nil output", func(t *testing.T) { result := copyAgentCallbackOutput(nil, 3) assert.Len(t, result, 3) for _, r := range result { assert.Nil(t, r) } }) t.Run("output with nil Events", func(t *testing.T) { out := &AgentCallbackOutput{Events: nil} result := copyAgentCallbackOutput(out, 3) assert.Len(t, result, 3) for _, r := range result { assert.Equal(t, out, r) } }) t.Run("valid output with events", func(t *testing.T) { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { gen.Send(&AgentEvent{AgentName: "test"}) gen.Close() }() out := &AgentCallbackOutput{Events: iter} result := copyAgentCallbackOutput(out, 2) assert.Len(t, result, 2) for i, r := range result { assert.NotNil(t, r, "result[%d] should not be nil", i) assert.NotNil(t, r.Events, "result[%d].Events should not be nil", i) } }) } func TestConvAgentCallbackInput(t *testing.T) { t.Run("valid AgentCallbackInput", func(t *testing.T) { input := &AgentCallbackInput{ Input: &AgentInput{Messages: []Message{schema.UserMessage("test")}}, } result := ConvAgentCallbackInput(input) assert.Equal(t, input, result) }) t.Run("invalid type returns nil", func(t *testing.T) { result := ConvAgentCallbackInput("invalid") assert.Nil(t, result) }) t.Run("nil returns nil", func(t *testing.T) { result := ConvAgentCallbackInput(nil) assert.Nil(t, result) }) } func TestConvAgentCallbackOutput(t *testing.T) { t.Run("valid AgentCallbackOutput", func(t *testing.T) { iter, _ := NewAsyncIteratorPair[*AgentEvent]() output := &AgentCallbackOutput{Events: iter} result := ConvAgentCallbackOutput(output) assert.Equal(t, output, result) }) t.Run("invalid type returns nil", func(t *testing.T) { result := ConvAgentCallbackOutput("invalid") assert.Nil(t, result) }) t.Run("nil returns nil", func(t *testing.T) { result := ConvAgentCallbackOutput(nil) assert.Nil(t, 
result) }) } type mockTyperAgent struct { name string agentType string } func (a *mockTyperAgent) Name(_ context.Context) string { return a.name } func (a *mockTyperAgent) Description(_ context.Context) string { return "mock agent" } func (a *mockTyperAgent) GetType() string { return a.agentType } func (a *mockTyperAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() gen.Close() return iter } type mockNonTyperAgent struct { name string } func (a *mockNonTyperAgent) Name(_ context.Context) string { return a.name } func (a *mockNonTyperAgent) Description(_ context.Context) string { return "mock agent" } func (a *mockNonTyperAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() gen.Close() return iter } func TestGetAgentType(t *testing.T) { t.Run("agent implementing Typer", func(t *testing.T) { agent := &mockTyperAgent{name: "test", agentType: "CustomType"} result := getAgentType(agent) assert.Equal(t, "CustomType", result) }) t.Run("agent not implementing Typer", func(t *testing.T) { agent := &mockNonTyperAgent{name: "test"} result := getAgentType(agent) assert.Equal(t, "", result) }) } func TestWithCallbacksOption(t *testing.T) { handler := callbacks.NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { return ctx }). 
Build() opt := WithCallbacks(handler) opts := getCommonOptions(nil, opt) assert.Len(t, opts.handlers, 1) } func TestWithMultipleCallbacksOption(t *testing.T) { handler1 := callbacks.NewHandlerBuilder().Build() handler2 := callbacks.NewHandlerBuilder().Build() opt := WithCallbacks(handler1, handler2) opts := getCommonOptions(nil, opt) assert.Len(t, opts.handlers, 2) } ================================================ FILE: adk/chatmodel.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adk import ( "bytes" "context" "encoding/gob" "errors" "fmt" "math" "runtime/debug" "sync" "sync/atomic" "github.com/bytedance/sonic" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/internal/safe" "github.com/cloudwego/eino/schema" ) type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool generator *AsyncGenerator[*AgentEvent] } func (e *chatModelAgentExecCtx) send(event *AgentEvent) { if e != nil && e.generator != nil { e.generator.Send(event) } } type chatModelAgentExecCtxKey struct{} func withChatModelAgentExecCtx(ctx context.Context, execCtx *chatModelAgentExecCtx) context.Context { return context.WithValue(ctx, chatModelAgentExecCtxKey{}, execCtx) } func getChatModelAgentExecCtx(ctx context.Context) *chatModelAgentExecCtx { if v := ctx.Value(chatModelAgentExecCtxKey{}); v != nil { return v.(*chatModelAgentExecCtx) } return nil } type chatModelAgentRunOptions struct { chatModelOptions []model.Option toolOptions []tool.Option agentToolOptions map[string][]AgentRunOption historyModifier func(context.Context, []Message) []Message } // WithChatModelOptions sets options for the underlying chat model. func WithChatModelOptions(opts []model.Option) AgentRunOption { return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) { t.chatModelOptions = opts }) } // WithToolOptions sets options for tools used by the chat model agent. func WithToolOptions(opts []tool.Option) AgentRunOption { return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) { t.toolOptions = opts }) } // WithAgentToolRunOptions specifies per-tool run options for the agent. 
func WithAgentToolRunOptions(opts map[string][]AgentRunOption) AgentRunOption { return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) { t.agentToolOptions = opts }) } // WithHistoryModifier sets a function to modify history during resume. // Deprecated: use ResumeWithData and ChatModelAgentResumeData instead. func WithHistoryModifier(f func(context.Context, []Message) []Message) AgentRunOption { return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) { t.historyModifier = f }) } type ToolsConfig struct { compose.ToolsNodeConfig // ReturnDirectly specifies tools that cause the agent to return immediately when called. // If multiple listed tools are called simultaneously, only the first one triggers the return. // The map keys are tool names indicate whether the tool should trigger immediate return. ReturnDirectly map[string]bool // EmitInternalEvents indicates whether internal events from agentTool should be emitted // to the parent agent's AsyncGenerator, allowing real-time streaming of nested agent output // to the end-user via Runner. // // Note that these forwarded events are NOT recorded in the parent agent's runSession. // They are only emitted to the end-user and have no effect on the parent agent's state // or checkpoint. // // Action Scoping: // Actions emitted by the inner agent are scoped to the agent tool boundary: // - Interrupted: Propagated via CompositeInterrupt to allow proper interrupt/resume // - Exit, TransferToAgent, BreakLoop: Ignored outside the agent tool EmitInternalEvents bool } // GenModelInput transforms agent instructions and input into a format suitable for the model. 
type GenModelInput func(ctx context.Context, instruction string, input *AgentInput) ([]Message, error) func defaultGenModelInput(ctx context.Context, instruction string, input *AgentInput) ([]Message, error) { msgs := make([]Message, 0, len(input.Messages)+1) if instruction != "" { sp := schema.SystemMessage(instruction) vs := GetSessionValues(ctx) if len(vs) > 0 { ct := prompt.FromMessages(schema.FString, sp) ms, err := ct.Format(ctx, vs) if err != nil { return nil, fmt.Errorf("defaultGenModelInput: failed to format instruction using FString template. "+ "This formatting is triggered automatically when SessionValues are present. "+ "If your instruction contains literal curly braces (e.g., JSON), provide a custom GenModelInput that uses another format. If you are using "+ "SessionValues for purposes other than instruction formatting, provide a custom GenModelInput that does no formatting at all: %w", err) } sp = ms[0] } msgs = append(msgs, sp) } msgs = append(msgs, input.Messages...) return msgs, nil } // ChatModelAgentState represents the state of a chat model agent during conversation. // This is the primary state type for both ChatModelAgentMiddleware and AgentMiddleware callbacks. type ChatModelAgentState struct { // Messages contains all messages in the current conversation session. Messages []Message } // AgentMiddleware provides hooks to customize agent behavior at various stages of execution. // // Limitations of AgentMiddleware (struct-based): // - Struct types are closed: users cannot add new methods // - Callbacks only return error, cannot return modified context // - Configuration is scattered across closures when using factory functions // // For new code requiring extensibility, consider using ChatModelAgentMiddleware (interface-based) instead. // AgentMiddleware is kept for backward compatibility and remains suitable for simple, // static additions like extra instruction or tools. 
// // See ChatModelAgentMiddleware documentation for detailed comparison. type AgentMiddleware struct { // AdditionalInstruction adds supplementary text to the agent's system instruction. // This instruction is concatenated with the base instruction before each chat model call. AdditionalInstruction string // AdditionalTools adds supplementary tools to the agent's available toolset. // These tools are combined with the tools configured for the agent. AdditionalTools []tool.BaseTool // BeforeChatModel is called before each ChatModel invocation, allowing modification of the agent state. BeforeChatModel func(context.Context, *ChatModelAgentState) error // AfterChatModel is called after each ChatModel invocation, allowing modification of the agent state. AfterChatModel func(context.Context, *ChatModelAgentState) error // WrapToolCall wraps tool calls with custom middleware logic. // Each middleware contains Invokable and/or Streamable functions for tool calls. WrapToolCall compose.ToolMiddleware } type ChatModelAgentConfig struct { // Name of the agent. Better be unique across all agents. Name string // Description of the agent's capabilities. // Helps other agents determine whether to transfer tasks to this agent. Description string // Instruction used as the system prompt for this agent. // Optional. If empty, no system prompt will be used. // Supports f-string placeholders for session values in default GenModelInput, for example: // "You are a helpful assistant. The current time is {Time}. The current user is {User}." // These placeholders will be replaced with session values for "Time" and "User". Instruction string // Model is the chat model used by the agent. // If your ChatModelAgent uses any tools, this model must support the model.WithTools // call option, as that's how ChatModelAgent configures the model with tool information. 
Model model.BaseChatModel ToolsConfig ToolsConfig // GenModelInput transforms instructions and input messages into the model's input format. // Optional. Defaults to defaultGenModelInput which combines instruction and messages. GenModelInput GenModelInput // Exit defines the tool used to terminate the agent process. // Optional. If nil, no Exit Action will be generated. // You can use the provided 'ExitTool' implementation directly. Exit tool.BaseTool // OutputKey stores the agent's response in the session. // Optional. When set, stores output via AddSessionValue(ctx, outputKey, msg.Content). OutputKey string // MaxIterations defines the upper limit of ChatModel generation cycles. // The agent will terminate with an error if this limit is exceeded. // Optional. Defaults to 20. MaxIterations int // Middlewares configures agent middleware for extending functionality. // Use for simple, static additions like extra instruction or tools. // Kept for backward compatibility; for new code, consider using Handlers instead. Middlewares []AgentMiddleware // Handlers configures interface-based handlers for extending agent behavior. // Unlike Middlewares (struct-based), Handlers allow users to: // - Add custom methods to their handler implementations // - Return modified context from handler methods // - Centralize configuration in struct fields instead of closures // // Handlers are processed after Middlewares, in registration order. // See ChatModelAgentMiddleware documentation for when to use Handlers vs Middlewares. // // Execution Order (relative to AgentMiddleware and ToolsConfig): // // Model call lifecycle (outermost to innermost wrapper chain): // 1. AgentMiddleware.BeforeChatModel (hook, runs before model call) // 2. ChatModelAgentMiddleware.BeforeModelRewriteState (hook, can modify state before model call) // 3. retryModelWrapper (internal - retries on failure, if configured) // 4. eventSenderModelWrapper (internal - sends model response events) // 5. 
ChatModelAgentMiddleware.WrapModel (wrapper, first registered is outermost) // 6. callbackInjectionModelWrapper (internal - injects callbacks if not enabled) // 7. Model.Generate/Stream // 8. ChatModelAgentMiddleware.AfterModelRewriteState (hook, can modify state after model call) // 9. AgentMiddleware.AfterChatModel (hook, runs after model call) // // Custom Event Sender Position: // By default, events are sent after all user middlewares (WrapModel) have processed the output, // containing the modified messages. To send events with original (unmodified) output, pass // NewEventSenderModelWrapper as a Handler after the modifying middleware: // // agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ // Handlers: []adk.ChatModelAgentMiddleware{ // myCustomHandler, // First registered = outermost wrapper // adk.NewEventSenderModelWrapper(), // Last registered = innermost, events sent with original output // }, // }) // // Handler order: first registered is outermost. So [A, B, C] becomes A(B(C(model))). // EventSenderModelWrapper sends events in post-processing, so placing it innermost // means it receives the original model output before outer handlers modify it. // // When EventSenderModelWrapper is detected in Handlers, the framework skips // the default event sender to avoid duplicate events. // // Tool call lifecycle (outermost to innermost): // 1. eventSenderToolHandler (internal ToolMiddleware - sends tool result events after all processing) // 2. ToolsConfig.ToolCallMiddlewares (ToolMiddleware) // 3. AgentMiddleware.WrapToolCall (ToolMiddleware) // 4. ChatModelAgentMiddleware.WrapToolCall (wrapper, first registered is outermost) // 5. callbackInjectedToolCall (internal - injects callbacks if tool doesn't handle them) // 6. Tool.InvokableRun/StreamableRun // // Tool List Modification: // // There are two ways to modify the tool list: // // 1. In BeforeAgent: Modify ChatModelAgentContext.Tools ([]tool.BaseTool) directly. 
This affects // both the tool info list passed to ChatModel AND the actual tools available for // execution. Changes persist for the entire agent run. // // 2. In WrapModel: Create a model wrapper that modifies the tool info list per model // request using model.WithTools(toolInfos). This ONLY affects the tool info list // passed to ChatModel, NOT the actual tools available for execution. Use this for // dynamic tool filtering/selection based on conversation context. The modification // is scoped to this model request only. Handlers []ChatModelAgentMiddleware // ModelRetryConfig configures retry behavior for the ChatModel. // When set, the agent will automatically retry failed ChatModel calls // based on the configured policy. // Optional. If nil, no retry will be performed. ModelRetryConfig *ModelRetryConfig } type ChatModelAgent struct { name string description string instruction string model model.BaseChatModel toolsConfig ToolsConfig genModelInput GenModelInput outputKey string maxIterations int subAgents []Agent parentAgent Agent disallowTransferToParent bool exit tool.BaseTool handlers []ChatModelAgentMiddleware middlewares []AgentMiddleware modelRetryConfig *ModelRetryConfig once sync.Once run runFunc frozen uint32 exeCtx *execContext } type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, instruction string, returnDirectly map[string]bool, opts ...compose.Option) // NewChatModelAgent constructs a chat model-backed agent with the provided config. 
func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { if config.Name == "" { return nil, errors.New("agent 'Name' is required") } if config.Description == "" { return nil, errors.New("agent 'Description' is required") } if config.Model == nil { return nil, errors.New("agent 'Model' is required") } genInput := defaultGenModelInput if config.GenModelInput != nil { genInput = config.GenModelInput } tc := config.ToolsConfig // Tool call middleware execution order (outermost to innermost): // 1. eventSenderToolHandler (internal - sends tool result events after all modifications) // 2. User-provided ToolsConfig.ToolCallMiddlewares (original order preserved) // 3. Middlewares' WrapToolCall (in registration order) // 4. ChatModelAgentMiddleware.WrapToolCall (in registration order) // 5. callbackInjectedToolCall (internal - injects callbacks if tool doesn't handle them) eventSender := &eventSenderToolHandler{} tc.ToolCallMiddlewares = append( []compose.ToolMiddleware{{Invokable: eventSender.WrapInvokableToolCall, Streamable: eventSender.WrapStreamableToolCall, EnhancedInvokable: eventSender.WrapEnhancedInvokableToolCall, EnhancedStreamable: eventSender.WrapEnhancedStreamableToolCall, }}, tc.ToolCallMiddlewares..., ) tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, collectToolMiddlewaresFromMiddlewares(config.Middlewares)...) 
return &ChatModelAgent{ name: config.Name, description: config.Description, instruction: config.Instruction, model: config.Model, toolsConfig: tc, genModelInput: genInput, exit: config.Exit, outputKey: config.OutputKey, maxIterations: config.MaxIterations, handlers: config.Handlers, middlewares: config.Middlewares, modelRetryConfig: config.ModelRetryConfig, }, nil } func collectToolMiddlewaresFromMiddlewares(mws []AgentMiddleware) []compose.ToolMiddleware { var middlewares []compose.ToolMiddleware for _, m := range mws { if m.WrapToolCall.Invokable == nil && m.WrapToolCall.Streamable == nil { continue } middlewares = append(middlewares, m.WrapToolCall) } return middlewares } const ( TransferToAgentToolName = "transfer_to_agent" TransferToAgentToolDesc = "Transfer the question to another agent." TransferToAgentToolDescChinese = "将问题移交给其他 Agent。" ) var ( toolInfoTransferToAgent = &schema.ToolInfo{ Name: TransferToAgentToolName, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "agent_name": { Desc: "the name of the agent to transfer to", Required: true, Type: schema.String, }, }), } ToolInfoExit = &schema.ToolInfo{ Name: "exit", Desc: "Exit the agent process and return the final result.", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "final_result": { Desc: "the final result to return", Required: true, Type: schema.String, }, }), } ) type ExitTool struct{} func (et ExitTool) Info(_ context.Context) (*schema.ToolInfo, error) { return ToolInfoExit, nil } func (et ExitTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { type exitParams struct { FinalResult string `json:"final_result"` } params := &exitParams{} err := sonic.UnmarshalString(argumentsInJSON, params) if err != nil { return "", err } err = SendToolGenAction(ctx, "exit", NewExitAction()) if err != nil { return "", err } return params.FinalResult, nil } type transferToAgent struct{} func (tta 
transferToAgent) Info(_ context.Context) (*schema.ToolInfo, error) { desc := internal.SelectPrompt(internal.I18nPrompts{ English: TransferToAgentToolDesc, Chinese: TransferToAgentToolDescChinese, }) info := *toolInfoTransferToAgent info.Desc = desc return &info, nil } func transferToAgentToolOutput(destName string) string { tpl := internal.SelectPrompt(internal.I18nPrompts{ English: "successfully transferred to agent [%s]", Chinese: "成功移交任务至 agent [%s]", }) return fmt.Sprintf(tpl, destName) } func (tta transferToAgent) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { type transferParams struct { AgentName string `json:"agent_name"` } params := &transferParams{} err := sonic.UnmarshalString(argumentsInJSON, params) if err != nil { return "", err } err = SendToolGenAction(ctx, TransferToAgentToolName, NewTransferToAgentAction(params.AgentName)) if err != nil { return "", err } return transferToAgentToolOutput(params.AgentName), nil } func (a *ChatModelAgent) Name(_ context.Context) string { return a.name } func (a *ChatModelAgent) Description(_ context.Context) string { return a.description } func (a *ChatModelAgent) GetType() string { return "ChatModel" } func (a *ChatModelAgent) OnSetSubAgents(_ context.Context, subAgents []Agent) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been frozen after run") } if len(a.subAgents) > 0 { return errors.New("agent's sub-agents has already been set") } a.subAgents = subAgents return nil } func (a *ChatModelAgent) OnSetAsSubAgent(_ context.Context, parent Agent) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been frozen after run") } if a.parentAgent != nil { return errors.New("agent has already been set as a sub-agent of another agent") } a.parentAgent = parent return nil } func (a *ChatModelAgent) OnDisallowTransferToParent(_ context.Context) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been 
frozen after run") } a.disallowTransferToParent = true return nil } type ChatModelAgentInterruptInfo struct { Info *compose.InterruptInfo Data []byte } func init() { schema.RegisterName[*ChatModelAgentInterruptInfo]("_eino_adk_chat_model_agent_interrupt_info") } func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStream, outputKey string) error { if msg != nil { AddSessionValue(ctx, outputKey, msg.Content) return nil } concatenated, err := schema.ConcatMessageStream(msgStream) if err != nil { return err } AddSessionValue(ctx, outputKey, concatenated.Content) return nil } func errFunc(err error) runFunc { return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ string, _ map[string]bool, _ ...compose.Option) { generator.Send(&AgentEvent{Err: err}) } } // ChatModelAgentResumeData holds data that can be provided to a ChatModelAgent during a resume operation // to modify its behavior. It is provided via the adk.ResumeWithData function. type ChatModelAgentResumeData struct { // HistoryModifier is a function that can transform the agent's message history before it is sent to the model. // This allows for adding new information or context upon resumption. 
HistoryModifier func(ctx context.Context, history []Message) []Message } type execContext struct { instruction string toolsNodeConf compose.ToolsNodeConfig returnDirectly map[string]bool toolInfos []*schema.ToolInfo unwrappedTools []tool.BaseTool rebuildGraph bool // whether needs to instantiate a new graph because of topology changes due to tool modifications toolUpdated bool // whether needs to pass a compose.WithToolList option to ToolsNode due to tool list change } func (a *ChatModelAgent) applyBeforeAgent(ctx context.Context, ec *execContext) (context.Context, *execContext, error) { runCtx := &ChatModelAgentContext{ Instruction: ec.instruction, Tools: cloneSlice(ec.unwrappedTools), ReturnDirectly: copyMap(ec.returnDirectly), } var err error for i, handler := range a.handlers { ctx, runCtx, err = handler.BeforeAgent(ctx, runCtx) if err != nil { return ctx, nil, fmt.Errorf("handler[%d] (%T) BeforeAgent failed: %w", i, handler, err) } } runtimeEC := &execContext{ instruction: runCtx.Instruction, toolsNodeConf: compose.ToolsNodeConfig{ Tools: runCtx.Tools, ToolCallMiddlewares: cloneSlice(ec.toolsNodeConf.ToolCallMiddlewares), }, returnDirectly: runCtx.ReturnDirectly, toolUpdated: true, rebuildGraph: (len(ec.toolsNodeConf.Tools) == 0 && len(runCtx.Tools) > 0) || (len(ec.returnDirectly) == 0 && len(runCtx.ReturnDirectly) > 0), } toolInfos, err := genToolInfos(ctx, &runtimeEC.toolsNodeConf) if err != nil { return ctx, nil, err } runtimeEC.toolInfos = toolInfos return ctx, runtimeEC, nil } func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, error) { instruction := a.instruction toolsNodeConf := compose.ToolsNodeConfig{ Tools: cloneSlice(a.toolsConfig.Tools), ToolCallMiddlewares: cloneSlice(a.toolsConfig.ToolCallMiddlewares), UnknownToolsHandler: a.toolsConfig.UnknownToolsHandler, ExecuteSequentially: a.toolsConfig.ExecuteSequentially, ToolArgumentsHandler: a.toolsConfig.ToolArgumentsHandler, } returnDirectly := 
copyMap(a.toolsConfig.ReturnDirectly) transferToAgents := a.subAgents if a.parentAgent != nil && !a.disallowTransferToParent { transferToAgents = append(transferToAgents, a.parentAgent) } if len(transferToAgents) > 0 { transferInstruction := genTransferToAgentInstruction(ctx, transferToAgents) instruction = concatInstructions(instruction, transferInstruction) toolsNodeConf.Tools = append(toolsNodeConf.Tools, &transferToAgent{}) returnDirectly[TransferToAgentToolName] = true } if a.exit != nil { toolsNodeConf.Tools = append(toolsNodeConf.Tools, a.exit) exitInfo, err := a.exit.Info(ctx) if err != nil { return nil, err } returnDirectly[exitInfo.Name] = true } for _, m := range a.middlewares { if m.AdditionalInstruction != "" { instruction = concatInstructions(instruction, m.AdditionalInstruction) } toolsNodeConf.Tools = append(toolsNodeConf.Tools, m.AdditionalTools...) } unwrappedTools := cloneSlice(toolsNodeConf.Tools) handlerMiddlewares := handlersToToolMiddlewares(a.handlers) toolsNodeConf.ToolCallMiddlewares = append(toolsNodeConf.ToolCallMiddlewares, handlerMiddlewares...) toolInfos, err := genToolInfos(ctx, &toolsNodeConf) if err != nil { return nil, err } return &execContext{ instruction: instruction, toolsNodeConf: toolsNodeConf, returnDirectly: returnDirectly, toolInfos: toolInfos, unwrappedTools: unwrappedTools, }, nil } func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ handlers: a.handlers, middlewares: a.middlewares, retryConfig: a.modelRetryConfig, }) type noToolsInput struct { input *AgentInput instruction string } return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, instruction string, _ map[string]bool, opts ...compose.Option) { chain := compose.NewChain[noToolsInput, Message]( compose.WithGenLocalState(func(ctx context.Context) (state *State) { return &State{} })). 
AppendLambda(compose.InvokableLambda(func(ctx context.Context, in noToolsInput) ([]Message, error) { messages, err := a.genModelInput(ctx, in.instruction, in.input) if err != nil { return nil, err } return messages, nil })). AppendChatModel(wrappedModel) r, err := chain.Compile(ctx, compose.WithGraphName(a.name), compose.WithCheckPointStore(store), compose.WithSerializer(&gobSerializer{})) if err != nil { generator.Send(&AgentEvent{Err: err}) return } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ generator: generator, }) in := noToolsInput{input: input, instruction: instruction} var msg Message var msgStream MessageStream if input.EnableStreaming { msgStream, err = r.Stream(ctx, in, opts...) } else { msg, err = r.Invoke(ctx, in, opts...) } if err == nil { if a.outputKey != "" { err = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err != nil { generator.Send(&AgentEvent{Err: err}) } } else if msgStream != nil { msgStream.Close() } } else { generator.Send(&AgentEvent{Err: err}) } } } func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) (runFunc, error) { conf := &reactConfig{ model: a.model, toolsConfig: &bc.toolsNodeConf, modelWrapperConf: &modelWrapperConfig{ handlers: a.handlers, middlewares: a.middlewares, retryConfig: a.modelRetryConfig, toolInfos: bc.toolInfos, }, toolsReturnDirectly: bc.returnDirectly, agentName: a.name, maxIterations: a.maxIterations, } type reactRunInput struct { input *AgentInput instruction string } return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, instruction string, returnDirectly map[string]bool, opts ...compose.Option) { g, err := newReact(ctx, conf) if err != nil { generator.Send(&AgentEvent{Err: err}) return } chain := compose.NewChain[reactRunInput, Message](). 
AppendLambda( compose.InvokableLambda(func(ctx context.Context, in reactRunInput) (*reactInput, error) { messages, genErr := a.genModelInput(ctx, in.instruction, in.input) if genErr != nil { return nil, genErr } return &reactInput{ messages: messages, }, nil }), ). AppendGraph(g, compose.WithNodeName("ReAct"), compose.WithGraphCompileOptions(compose.WithMaxRunSteps(math.MaxInt))) var compileOptions []compose.GraphCompileOption compileOptions = append(compileOptions, compose.WithGraphName(a.name), compose.WithCheckPointStore(store), compose.WithSerializer(&gobSerializer{}), compose.WithMaxRunSteps(math.MaxInt)) runnable, err_ := chain.Compile(ctx, compileOptions...) if err_ != nil { generator.Send(&AgentEvent{Err: err_}) return } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ runtimeReturnDirectly: returnDirectly, generator: generator, }) in := reactRunInput{ input: input, instruction: instruction, } var runOpts []compose.Option runOpts = append(runOpts, opts...) if a.toolsConfig.EmitInternalEvents { runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(generator)))) } if input.EnableStreaming { runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true)))) } var msg Message var msgStream MessageStream if input.EnableStreaming { msgStream, err_ = runnable.Stream(ctx, in, runOpts...) } else { msg, err_ = runnable.Invoke(ctx, in, runOpts...) 
} if err_ == nil { if a.outputKey != "" { err_ = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err_ != nil { generator.Send(&AgentEvent{Err: err_}) } } else if msgStream != nil { msgStream.Close() } return } info, ok := compose.ExtractInterruptInfo(err_) if !ok { generator.Send(&AgentEvent{Err: err_}) return } data, existed, err := store.Get(ctx, bridgeCheckpointID) if err != nil { generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", err)}) return } if !existed { generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) return } is := FromInterruptContexts(info.InterruptContexts) event := CompositeInterrupt(ctx, info, data, is) event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ Info: info, Data: data, } event.AgentName = a.name generator.Send(event) }, nil } func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc { a.once.Do(func() { ec, err := a.prepareExecContext(ctx) if err != nil { a.run = errFunc(err) return } a.exeCtx = ec if len(ec.toolsNodeConf.Tools) == 0 { a.run = a.buildNoToolsRunFunc(ctx) return } run, err := a.buildReactRunFunc(ctx, ec) if err != nil { a.run = errFunc(err) return } a.run = run }) atomic.StoreUint32(&a.frozen, 1) return a.run } func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFunc, *execContext, error) { defaultRun := a.buildRunFunc(ctx) bc := a.exeCtx if bc == nil { return ctx, defaultRun, bc, nil } if len(a.handlers) == 0 { runtimeBC := &execContext{ instruction: bc.instruction, toolsNodeConf: bc.toolsNodeConf, returnDirectly: bc.returnDirectly, toolInfos: bc.toolInfos, } return ctx, defaultRun, runtimeBC, nil } ctx, runtimeBC, err := a.applyBeforeAgent(ctx, bc) if err != nil { return ctx, nil, nil, err } if !runtimeBC.rebuildGraph { return ctx, defaultRun, runtimeBC, nil } var tempRun runFunc if len(runtimeBC.toolsNodeConf.Tools) == 0 { tempRun = 
a.buildNoToolsRunFunc(ctx) } else { tempRun, err = a.buildReactRunFunc(ctx, runtimeBC) if err != nil { return ctx, nil, nil, err } } return ctx, tempRun, runtimeBC, nil } func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() ctx, run, bc, err := a.getRunFunc(ctx) if err != nil { go func() { generator.Send(&AgentEvent{Err: err}) generator.Close() }() return iterator } co := getComposeOptions(opts) co = append(co, compose.WithCheckPointID(bridgeCheckpointID)) if bc != nil { co = append(co, compose.WithChatModelOption(model.WithTools(bc.toolInfos))) if bc.toolUpdated { co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...))) } } go func() { defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) generator.Send(&AgentEvent{Err: e}) } generator.Close() }() var ( instruction string returnDirectly map[string]bool ) if bc != nil { instruction = bc.instruction returnDirectly = bc.returnDirectly } run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, co...) 
}() return iterator } func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() ctx, run, bc, err := a.getRunFunc(ctx) if err != nil { go func() { generator.Send(&AgentEvent{Err: err}) generator.Close() }() return iterator } co := getComposeOptions(opts) co = append(co, compose.WithCheckPointID(bridgeCheckpointID)) if bc != nil { co = append(co, compose.WithChatModelOption(model.WithTools(bc.toolInfos))) if bc.toolUpdated { co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...))) } } if info.InterruptState == nil { panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has no state", a.Name(ctx))) } stateByte, ok := info.InterruptState.([]byte) if !ok { panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid interrupt state type: %T", a.Name(ctx), info.InterruptState)) } // Migrate legacy checkpoints before resume. // This covers both: // - v0.7.*: state is stored as a struct wire type (stateV07) under the legacy name. // - v0.8.0-v0.8.3: state is stored as a GobEncoder payload under the same legacy name and must // be routed to a GobDecode-compatible compat type via byte-patching. // The result is re-encoded so the resume path always operates on the current *State. 
stateByte, err = preprocessComposeCheckpoint(stateByte) if err != nil { go func() { generator.Send(&AgentEvent{Err: err}) generator.Close() }() return iterator } var historyModifier func(ctx context.Context, history []Message) []Message if info.ResumeData != nil { resumeData, ok := info.ResumeData.(*ChatModelAgentResumeData) if !ok { panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid resume data type: %T", a.Name(ctx), info.ResumeData)) } historyModifier = resumeData.HistoryModifier } if historyModifier != nil { co = append(co, compose.WithStateModifier(func(ctx context.Context, path compose.NodePath, state any) error { s, ok := state.(*State) if !ok { return nil } s.Messages = historyModifier(ctx, s.Messages) return nil })) } go func() { defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) generator.Send(&AgentEvent{Err: e}) } generator.Close() }() var ( instruction string returnDirectly map[string]bool ) if bc != nil { instruction = bc.instruction returnDirectly = bc.returnDirectly } run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, newResumeBridgeStore(stateByte), instruction, returnDirectly, co...) }() return iterator } func getComposeOptions(opts []AgentRunOption) []compose.Option { o := GetImplSpecificOptions[chatModelAgentRunOptions](nil, opts...) var co []compose.Option if len(o.chatModelOptions) > 0 { co = append(co, compose.WithChatModelOption(o.chatModelOptions...)) } var to []tool.Option if len(o.toolOptions) > 0 { to = append(to, o.toolOptions...) 
} for toolName, atos := range o.agentToolOptions { to = append(to, withAgentToolOptions(toolName, atos)) } if len(to) > 0 { co = append(co, compose.WithToolsNodeOption(compose.WithToolOption(to...))) } if o.historyModifier != nil { co = append(co, compose.WithStateModifier(func(ctx context.Context, path compose.NodePath, state any) error { s, ok := state.(*State) if !ok { return fmt.Errorf("unexpected state type: %T, expected: %T", state, &State{}) } s.Messages = o.historyModifier(ctx, s.Messages) return nil })) } return co } type gobSerializer struct{} func (g *gobSerializer) Marshal(v any) ([]byte, error) { buf := new(bytes.Buffer) err := gob.NewEncoder(buf).Encode(v) if err != nil { return nil, err } return buf.Bytes(), nil } func (g *gobSerializer) Unmarshal(data []byte, v any) error { buf := bytes.NewBuffer(data) return gob.NewDecoder(buf).Decode(v) } // preprocessComposeCheckpoint migrates legacy compose checkpoints to the current format. // It handles the v0.8.0-v0.8.3 format: // - gob name "_eino_adk_state_v080_" (already byte-patched by preprocessADKCheckpoint // from "_eino_adk_react_state"), opaque-bytes wire format → decoded as *stateV080 // // v0.7 checkpoints need no migration — State is now a plain struct registered under the // same gob name, and gob handles missing fields gracefully. // // Fast path: if the legacy name is not present, skip entirely. func preprocessComposeCheckpoint(data []byte) ([]byte, error) { const lenPrefixedCompatName = "\x15" + stateGobNameV080 if bytes.Contains(data, []byte(lenPrefixedCompatName)) { // v0.8.0-v0.8.3: already byte-patched by preprocessADKCheckpoint; decode as *stateV080. 
migrated, err := compose.MigrateCheckpointState(data, &gobSerializer{}, func(state any) (any, bool, error) { sc, ok := state.(*stateV080) if !ok { return state, false, nil } return stateV080ToState(sc), true, nil }) if err != nil { return nil, fmt.Errorf("failed to migrate v0.8.0-v0.8.3 compose checkpoint: %w", err) } return migrated, nil } return data, nil } ================================================ FILE: adk/chatmodel_retry_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "errors" "io" "strings" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) var errRetryAble = errors.New("retry-able error") var errNonRetryAble = errors.New("non-retry-able error") func TestChatModelAgentRetry_NoTools_DirectError_Generate(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) var callCount int32 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { count := atomic.AddInt32(&callCount, 1) if count < 3 { return nil, errRetryAble } return schema.AssistantMessage("Success after retry", nil), nil }).Times(3) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: cm, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.Equal(t, "Success after retry", event.Output.MessageOutput.Message.Content) _, ok = iterator.Next() assert.False(t, ok) assert.Equal(t, int32(3), atomic.LoadInt32(&callCount)) } func TestChatModelAgentRetry_NoTools_DirectError_Stream(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) var callCount int32 cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { count := atomic.AddInt32(&callCount, 1) if count < 2 { return nil, errRetryAble } return schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("Success", nil), }), nil }).Times(2) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: cm, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, EnableStreaming: true, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.True(t, event.Output.MessageOutput.IsStreaming) _, ok = iterator.Next() assert.False(t, ok) assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) } type streamErrorModel struct { callCount int32 failAtChunk int maxFailures int tools []*schema.ToolInfo returnTool bool } func (m *streamErrorModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { return schema.AssistantMessage("Generated", nil), nil } func (m *streamErrorModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { count := atomic.AddInt32(&m.callCount, 1) sr, sw := schema.Pipe[*schema.Message](10) go func() { defer sw.Close() for i := 0; i < 5; i++ { if i == m.failAtChunk && int(count) <= m.maxFailures { sw.Send(nil, errRetryAble) return } if m.returnTool && i == 0 { sw.Send(schema.AssistantMessage("", []schema.ToolCall{{ ID: "call-1", Function: schema.FunctionCall{Name: "test_tool", Arguments: "{}"}, }}), nil) } else { 
sw.Send(schema.AssistantMessage("chunk", nil), nil) } } }() return sr, nil } func (m *streamErrorModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { m.tools = tools return m, nil } func TestChatModelAgentRetry_StreamError(t *testing.T) { t.Run("WithTools", func(t *testing.T) { ctx := context.Background() m := &streamErrorModel{ failAtChunk: 2, maxFailures: 2, returnTool: false, } config := &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: m, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{&fakeToolForTest{tarCount: 0}}, }, }, } agent, err := NewChatModelAgent(ctx, config) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, EnableStreaming: true, } iterator := agent.Run(ctx, input) var events []*AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } assert.Equal(t, 3, len(events)) var streamErrEventCount int var errs []error for i, event := range events { if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming { sr := event.Output.MessageOutput.MessageStream for { msg, err := sr.Recv() if err == io.EOF { break } if err != nil { streamErrEventCount++ errs = append(errs, err) t.Logf("event %d: err: %v", i, err) break } t.Logf("event %d: %v", i, msg.Content) } } } assert.Equal(t, 2, streamErrEventCount) assert.Equal(t, 2, len(errs)) var willRetryErr *WillRetryError assert.True(t, errors.As(errs[0], &willRetryErr)) assert.True(t, errors.As(errs[1], &willRetryErr)) assert.Equal(t, int32(3), atomic.LoadInt32(&m.callCount)) }) t.Run("NoTools", func(t *testing.T) { ctx := context.Background() m := &streamErrorModel{ failAtChunk: 
2, maxFailures: 2, returnTool: false, } config := &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: m, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, } agent, err := NewChatModelAgent(ctx, config) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, EnableStreaming: true, } iterator := agent.Run(ctx, input) var events []*AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } assert.Equal(t, 3, len(events)) var streamErrEventCount int var errs []error for i, event := range events { if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming { sr := event.Output.MessageOutput.MessageStream for { msg, err := sr.Recv() if err == io.EOF { break } if err != nil { streamErrEventCount++ errs = append(errs, err) t.Logf("event %d: err: %v", i, err) break } t.Logf("event %d: %v", i, msg.Content) } } } assert.Equal(t, 2, streamErrEventCount) assert.Equal(t, 2, len(errs)) var willRetryErr *WillRetryError assert.True(t, errors.As(errs[0], &willRetryErr)) assert.True(t, errors.As(errs[1], &willRetryErr)) assert.Equal(t, int32(3), atomic.LoadInt32(&m.callCount)) }) } func TestChatModelAgentRetry_WithTools_DirectError_Generate(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) var callCount int32 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { count := atomic.AddInt32(&callCount, 1) if count < 2 { return nil, errRetryAble } return schema.AssistantMessage("Success after retry", nil), nil }).Times(2) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() fakeTool := &fakeToolForTest{tarCount: 0} agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, }, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.Equal(t, "Success after retry", event.Output.MessageOutput.Message.Content) _, ok = iterator.Next() assert.False(t, ok) assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) } func TestChatModelAgentRetry_NonRetryableError(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(nil, errNonRetryAble).Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: cm, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.NotNil(t, event.Err) assert.True(t, errors.Is(event.Err, errNonRetryAble)) _, ok = iterator.Next() assert.False(t, ok) } type inputCapturingModel struct { capturedInputs [][]Message } func (m *inputCapturingModel) Generate(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { m.capturedInputs = append(m.capturedInputs, input) return schema.AssistantMessage("Response from capturing model", nil), nil } func (m *inputCapturingModel) Stream(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { m.capturedInputs = append(m.capturedInputs, input) return schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("Response from capturing model", nil), }), nil } func (m *inputCapturingModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { return m, nil } func TestChatModelAgentRetry_MaxRetriesExhausted(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(nil, errRetryAble).Times(4) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: cm, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.NotNil(t, event.Err) assert.True(t, errors.Is(event.Err, ErrExceedMaxRetries)) var retryErr *RetryExhaustedError assert.True(t, errors.As(event.Err, &retryErr)) assert.True(t, errors.Is(retryErr.LastErr, errRetryAble)) _, ok = iterator.Next() assert.False(t, ok) } func TestChatModelAgentRetry_BackoffFunction(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) var backoffCalls []int var callCount int32 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { count := atomic.AddInt32(&callCount, 1) if count < 3 { return nil, errRetryAble } return schema.AssistantMessage("Success", nil), nil }).Times(3) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: cm, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, BackoffFunc: func(ctx context.Context, attempt int) time.Duration { backoffCalls = append(backoffCalls, attempt) return time.Millisecond }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.Nil(t, event.Err) _, ok = iterator.Next() assert.False(t, ok) assert.Equal(t, []int{1, 2}, backoffCalls) } func TestChatModelAgentRetry_NoRetryConfig(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(nil, errRetryAble).Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent without retry config", Instruction: "You are a helpful assistant.", Model: cm, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.NotNil(t, event.Err) assert.True(t, errors.Is(event.Err, errRetryAble)) _, ok = iterator.Next() assert.False(t, ok) } func TestChatModelAgentRetry_WithTools_NonRetryAbleStreamError(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, errNonRetryAble).Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() fakeTool := &fakeToolForTest{tarCount: 0} agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, }, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, EnableStreaming: true, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.NotNil(t, event.Err) assert.True(t, errors.Is(event.Err, errNonRetryAble)) _, ok = iterator.Next() assert.False(t, ok) } type nonRetryAbleStreamErrorModel struct { tools []*schema.ToolInfo } func (m *nonRetryAbleStreamErrorModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { return 
schema.AssistantMessage("Generated", nil), nil } func (m *nonRetryAbleStreamErrorModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { sr, sw := schema.Pipe[*schema.Message](10) go func() { defer sw.Close() sw.Send(schema.AssistantMessage("chunk1", nil), nil) sw.Send(nil, errNonRetryAble) }() return sr, nil } func (m *nonRetryAbleStreamErrorModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { m.tools = tools return m, nil } func TestChatModelAgentRetry_NoTools_NonRetryAbleStreamError(t *testing.T) { ctx := context.Background() m := &nonRetryAbleStreamErrorModel{} agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent for retry functionality", Instruction: "You are a helpful assistant.", Model: m, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, EnableStreaming: true, } iterator := agent.Run(ctx, input) var events []*AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } assert.Equal(t, 2, len(events)) event0 := events[0] assert.NotNil(t, event0.Output) assert.NotNil(t, event0.Output.MessageOutput) assert.True(t, event0.Output.MessageOutput.IsStreaming) sr := event0.Output.MessageOutput.MessageStream var streamErr error for { _, err := sr.Recv() if err == io.EOF { break } if err != nil { streamErr = err break } } assert.NotNil(t, streamErr) assert.True(t, errors.Is(streamErr, errNonRetryAble), "Stream error should be the original error") event1 := events[1] assert.NotNil(t, event1.Err) assert.True(t, errors.Is(event1.Err, errNonRetryAble)) } func TestDefaultBackoff(t *testing.T) { ctx := context.Background() d1 := defaultBackoff(ctx, 1) d2 := defaultBackoff(ctx, 2) d3 := 
defaultBackoff(ctx, 3) t.Logf("Backoff delays: d1=%v, d2=%v, d3=%v", d1, d2, d3) assert.True(t, d1 >= 100*time.Millisecond && d1 < 150*time.Millisecond, "First retry should be ~100ms + jitter (0-50ms), got %v", d1) assert.True(t, d2 >= 200*time.Millisecond && d2 < 300*time.Millisecond, "Second retry should be ~200ms + jitter (0-100ms), got %v", d2) assert.True(t, d3 >= 400*time.Millisecond && d3 < 600*time.Millisecond, "Third retry should be ~400ms + jitter (0-200ms), got %v", d3) d10 := defaultBackoff(ctx, 10) t.Logf("Backoff delay for attempt 10: %v", d10) assert.True(t, d10 >= 10*time.Second && d10 <= 15*time.Second, "Delay should be capped at 10s + jitter (0-5s), got %v", d10) d100 := defaultBackoff(ctx, 100) t.Logf("Backoff delay for attempt 100: %v", d100) assert.True(t, d100 >= 10*time.Second && d100 <= 15*time.Second, "Delay should still be capped at 10s + jitter for very high attempts, got %v", d100) } func TestRetryExhaustedError_ErrorString(t *testing.T) { errWithLast := &RetryExhaustedError{ LastErr: errors.New("connection timeout"), TotalRetries: 3, } assert.Contains(t, errWithLast.Error(), "exceeds max retries") assert.Contains(t, errWithLast.Error(), "connection timeout") errWithoutLast := &RetryExhaustedError{ LastErr: nil, TotalRetries: 3, } assert.Equal(t, "exceeds max retries", errWithoutLast.Error()) } func TestWillRetryError_ErrorString(t *testing.T) { willRetry := &WillRetryError{ErrStr: "transient error", RetryAttempt: 1} assert.Equal(t, "transient error", willRetry.Error()) } type customError struct { code int msg string } func (e *customError) Error() string { return e.msg } func TestWillRetryError_Unwrap(t *testing.T) { originalErr := &customError{code: 500, msg: "internal error"} willRetry := &WillRetryError{ErrStr: originalErr.Error(), RetryAttempt: 1, err: originalErr} assert.True(t, errors.Is(willRetry, originalErr)) var targetErr *customError assert.True(t, errors.As(willRetry, &targetErr)) assert.Equal(t, 500, targetErr.code) 
assert.Equal(t, "internal error", targetErr.msg) } func TestChatModelAgentRetry_DefaultIsRetryAble(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) var callCount int32 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { count := atomic.AddInt32(&callCount, 1) if count < 2 { return nil, errors.New("any error should be retried") } return schema.AssistantMessage("Success", nil), nil }).Times(2) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "RetryTestAgent", Description: "Test agent with default IsRetryAble", Instruction: "You are a helpful assistant.", Model: cm, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.Equal(t, "Success", event.Output.MessageOutput.Message.Content) _, ok = iterator.Next() assert.False(t, ok) assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) } func TestSequentialWorkflow_RetryAbleStreamError_SuccessfulRetry(t *testing.T) { ctx := context.Background() retryModel := &streamErrorModel{ failAtChunk: 2, maxFailures: 2, } agentA, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "AgentA", Description: "Agent A with retry that emits stream errors then succeeds", Instruction: "You are agent A.", Model: retryModel, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) capturingModel := &inputCapturingModel{} agentB, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "AgentB", Description: "Agent B that captures input", Instruction: 
"You are agent B.", Model: capturingModel, }) assert.NoError(t, err) sequentialAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ Name: "SequentialAgent", Description: "Sequential agent A->B", SubAgents: []Agent{agentA, agentB}, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, EnableStreaming: true, } ctx, _ = initRunCtx(ctx, sequentialAgent.Name(ctx), input) iterator := sequentialAgent.Run(ctx, input) var events []*AgentEvent var willRetryErrCount int for { event, ok := iterator.Next() if !ok { break } events = append(events, event) if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming { sr := event.Output.MessageOutput.MessageStream for { _, err := sr.Recv() if err == io.EOF { break } if err != nil { var retryErr *WillRetryError if errors.As(err, &retryErr) { willRetryErrCount++ } break } } } } assert.Equal(t, 2, willRetryErrCount, "End-user should receive 2 WillRetryError events") assert.Equal(t, 1, len(capturingModel.capturedInputs), "Agent B should be called exactly once") successorInput := capturingModel.capturedInputs[0] var hasSuccessfulMessage bool for _, msg := range successorInput { if strings.Contains(msg.Content, "chunkchunkchunkchunkchunk") { hasSuccessfulMessage = true break } } assert.True(t, hasSuccessfulMessage, "Agent B should receive the successful message from Agent A") for _, msg := range successorInput { assert.NotContains(t, msg.Content, "retry-able error", "Agent B should not receive failed stream messages") } } type streamErrorModelNoRetry struct { callCount int32 } func (m *streamErrorModelNoRetry) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { return schema.AssistantMessage("Generated", nil), nil } func (m *streamErrorModelNoRetry) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { atomic.AddInt32(&m.callCount, 1) sr, sw 
:= schema.Pipe[*schema.Message](10) go func() { defer sw.Close() sw.Send(schema.AssistantMessage("chunk1", nil), nil) sw.Send(schema.AssistantMessage("chunk2", nil), nil) sw.Send(nil, errRetryAble) }() return sr, nil } func (m *streamErrorModelNoRetry) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { return m, nil } func TestSequentialWorkflow_NonRetryAbleStreamError_StopsFlow(t *testing.T) { ctx := context.Background() nonRetryModel := &nonRetryAbleStreamErrorModel{} agentA, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "AgentA", Description: "Agent A that emits non-retryable stream error", Instruction: "You are agent A.", Model: nonRetryModel, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) }, }, }) assert.NoError(t, err) capturingModel := &inputCapturingModel{} agentB, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "AgentB", Description: "Agent B that captures input", Instruction: "You are agent B.", Model: capturingModel, }) assert.NoError(t, err) sequentialAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ Name: "SequentialAgent", Description: "Sequential agent A->B", SubAgents: []Agent{agentA, agentB}, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, EnableStreaming: true, } ctx, _ = initRunCtx(ctx, sequentialAgent.Name(ctx), input) iterator := sequentialAgent.Run(ctx, input) var events []*AgentEvent var streamErrFound bool var finalErrEvent *AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) if event.Err != nil { finalErrEvent = event } if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming { sr := event.Output.MessageOutput.MessageStream for { _, err := sr.Recv() if err == io.EOF { break } if err != nil { streamErrFound = true assert.True(t, errors.Is(err, 
errNonRetryAble), "Stream error should be the original error") break } } } } assert.True(t, streamErrFound, "End-user should receive stream error") assert.NotNil(t, finalErrEvent, "Should receive a final error event") assert.True(t, errors.Is(finalErrEvent.Err, errNonRetryAble), "Final error should be the non-retryable error") assert.Equal(t, 0, len(capturingModel.capturedInputs), "Agent B should NOT be called due to error") } func TestSequentialWorkflow_NoRetryConfig_StreamError_StopsFlow(t *testing.T) { ctx := context.Background() noRetryModel := &streamErrorModelNoRetry{} agentA, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "AgentA", Description: "Agent A without retry config that emits stream error", Instruction: "You are agent A.", Model: noRetryModel, }) assert.NoError(t, err) capturingModel := &inputCapturingModel{} agentB, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "AgentB", Description: "Agent B that captures input", Instruction: "You are agent B.", Model: capturingModel, }) assert.NoError(t, err) sequentialAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ Name: "SequentialAgent", Description: "Sequential agent A->B", SubAgents: []Agent{agentA, agentB}, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, EnableStreaming: true, } ctx, _ = initRunCtx(ctx, sequentialAgent.Name(ctx), input) iterator := sequentialAgent.Run(ctx, input) var events []*AgentEvent var streamErrFound bool var finalErrEvent *AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) if event.Err != nil { finalErrEvent = event } if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming { sr := event.Output.MessageOutput.MessageStream for { _, err := sr.Recv() if err == io.EOF { break } if err != nil { streamErrFound = true assert.True(t, errors.Is(err, errRetryAble), "Stream error should be the original error") break } } 
} } assert.True(t, streamErrFound, "End-user should receive stream error") assert.NotNil(t, finalErrEvent, "Should receive a final error event") assert.True(t, errors.Is(finalErrEvent.Err, errRetryAble), "Final error should be the original error") assert.Equal(t, 0, len(capturingModel.capturedInputs), "Agent B should NOT be called due to error") assert.Equal(t, int32(1), atomic.LoadInt32(&noRetryModel.callCount), "Model should only be called once (no retry)") } ================================================ FILE: adk/chatmodel_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "errors" "testing" "time" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) // TestChatModelAgentRun tests the Run method of ChatModelAgent func TestChatModelAgentRun(t *testing.T) { // Basic test for Run method t.Run("BasicFunctionality", func(t *testing.T) { ctx := context.Background() // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.AssistantMessage("Hello, I am an AI assistant.", nil), nil). Times(1) // Create a ChatModelAgent agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for unit testing", Instruction: "You are a helpful assistant.", Model: cm, }) assert.NoError(t, err) assert.NotNil(t, agent) // Run the agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Hello, who are you?"), }, } iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) // Verify the message content msg := event.Output.MessageOutput.Message assert.NotNil(t, msg) assert.Equal(t, "Hello, I am an AI assistant.", msg.Content) // No more events _, ok = iterator.Next() assert.False(t, ok) }) t.Run("BasicChatModelWithAgentMiddleware", func(t *testing.T) { ctx := context.Background() // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("Hello, I am an AI assistant.", nil), nil). 
Times(1) afterChatModelExecuted := false // Create a ChatModelAgent agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for unit testing", Instruction: "You are a helpful assistant.", Model: cm, Middlewares: []AgentMiddleware{ { BeforeChatModel: func(ctx context.Context, state *ChatModelAgentState) error { state.Messages = append(state.Messages, schema.UserMessage("m")) return nil }, AfterChatModel: func(ctx context.Context, state *ChatModelAgentState) error { assert.Len(t, state.Messages, 4) afterChatModelExecuted = true return nil }, }, }, }) assert.NoError(t, err) assert.NotNil(t, agent) // Run the agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Hello, who are you?"), }, } iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) // Verify the message content msg := event.Output.MessageOutput.Message assert.NotNil(t, msg) assert.Equal(t, "Hello, I am an AI assistant.", msg.Content) // No more events _, ok = iterator.Next() assert.False(t, ok) assert.True(t, afterChatModelExecuted) }) t.Run("AfterChatModel_NoTools_ModifyDoesNotAffectEvent", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("original content", nil), nil). 
Times(1) var capturedMessages []*schema.Message agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for AfterChatModel NoTools scenario", Instruction: "You are a helpful assistant.", Model: cm, Middlewares: []AgentMiddleware{ { AfterChatModel: func(ctx context.Context, state *ChatModelAgentState) error { capturedMessages = make([]*schema.Message, len(state.Messages)) copy(capturedMessages, state.Messages) state.Messages = append(state.Messages, schema.AssistantMessage("appended content", nil)) return nil }, }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) msg := event.Output.MessageOutput.Message assert.NotNil(t, msg) assert.Equal(t, "original content", msg.Content) _, ok = iterator.Next() assert.False(t, ok) assert.Len(t, capturedMessages, 3) }) t.Run("AfterChatModel_ReAct_ModifyAffectsFlow", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) generateCount := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { generateCount++ if generateCount == 1 { return schema.AssistantMessage("first response with tool call", []schema.ToolCall{ {ID: "tc1", Function: schema.FunctionCall{Name: "test_tool", Arguments: "{}"}}, }), nil } return schema.AssistantMessage("final response", nil), nil }).AnyTimes() cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() toolCalled := false testTool := &fakeToolForTest{tarCount: 0} agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for AfterChatModel ReAct scenario", Instruction: "You are a helpful assistant.", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{testTool}, }, }, Middlewares: []AgentMiddleware{ { AfterChatModel: func(ctx context.Context, state *ChatModelAgentState) error { lastMsg := state.Messages[len(state.Messages)-1] if len(lastMsg.ToolCalls) > 0 { toolCalled = true state.Messages[len(state.Messages)-1] = schema.AssistantMessage("modified to remove tool call", nil) } return nil }, }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) var events []*AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } assert.True(t, toolCalled) assert.Equal(t, 1, generateCount) assert.Equal(t, 1, len(events)) event := events[0] assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) assert.Equal(t, "first response with tool call", event.Output.MessageOutput.Message.Content) assert.Len(t, event.Output.MessageOutput.Message.ToolCalls, 1) }) t.Run("AfterChatModel_ReAct_AppendToolCall_AffectsFlow", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) generateCount := 0 cm.EXPECT().Generate(gomock.Any(), 
gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { generateCount++ if generateCount == 1 { return schema.AssistantMessage("first response no tool", nil), nil } return schema.AssistantMessage("final response", nil), nil }).AnyTimes() cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() testTool := &fakeToolForTest{tarCount: 0} agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for AfterChatModel ReAct append tool call", Instruction: "You are a helpful assistant.", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{testTool}, }, }, Middlewares: []AgentMiddleware{ { AfterChatModel: func(ctx context.Context, state *ChatModelAgentState) error { if generateCount == 1 { state.Messages[len(state.Messages)-1] = schema.AssistantMessage("modified with tool call", []schema.ToolCall{ {ID: "tc1", Function: schema.FunctionCall{Name: "test_tool", Arguments: "{}"}}, }) } return nil }, }, }, }) assert.NoError(t, err) input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello")}, } iterator := agent.Run(ctx, input) var events []*AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } assert.Equal(t, 2, generateCount) assert.Equal(t, 3, len(events)) event0 := events[0] assert.NotNil(t, event0.Output) assert.NotNil(t, event0.Output.MessageOutput) assert.Equal(t, "first response no tool", event0.Output.MessageOutput.Message.Content) assert.Empty(t, event0.Output.MessageOutput.Message.ToolCalls) event2 := events[2] assert.NotNil(t, event2.Output) assert.NotNil(t, event2.Output.MessageOutput) assert.Equal(t, "final response", event2.Output.MessageOutput.Message.Content) }) // Test with streaming output t.Run("StreamOutput", func(t *testing.T) { ctx := context.Background() // Create a mock chat model ctrl := gomock.NewController(t) cm := 
mockModel.NewMockToolCallingChatModel(ctrl) // Create a stream reader for the mock response sr := schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("Hello", nil), schema.AssistantMessage(", I am", nil), schema.AssistantMessage(" an AI assistant.", nil), }) // Set up expectations for the mock model cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(sr, nil). Times(1) // Create a ChatModelAgent agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for unit testing", Instruction: "You are a helpful assistant.", Model: cm, }) assert.NoError(t, err) assert.NotNil(t, agent) // Run the agent with streaming enabled input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello, who are you?")}, EnableStreaming: true, } iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) assert.True(t, event.Output.MessageOutput.IsStreaming) // No more events _, ok = iterator.Next() assert.False(t, ok) }) // Test error handling t.Run("ErrorHandling", func(t *testing.T) { ctx := context.Background() // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model to return an error cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, errors.New("model error")). 
Times(1) // Create a ChatModelAgent agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for unit testing", Instruction: "You are a helpful assistant.", Model: cm, }) assert.NoError(t, err) assert.NotNil(t, agent) // Run the agent input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello, who are you?")}, } iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Get the event from the iterator, should contain an error event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.NotNil(t, event.Err) assert.Contains(t, event.Err.Error(), "model error") // No more events _, ok = iterator.Next() assert.False(t, ok) }) // Test with tools t.Run("WithTools", func(t *testing.T) { ctx := context.Background() // Create a fake tool for testing fakeTool := &fakeToolForTest{ tarCount: 1, } info, err := fakeTool.Info(ctx) assert.NoError(t, err) // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("Using tool", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"name": "test user"}`, }, }}), nil). Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("Task completed", nil), nil). 
Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create a ChatModelAgent with tools agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for unit testing", Instruction: "You are a helpful assistant.", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, }, }) assert.NoError(t, err) assert.NotNil(t, agent) // Run the agent input := &AgentInput{ Messages: []Message{schema.UserMessage("Use the test tool")}, } iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Get events from the iterator // First event should be the model output with tool call event1, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event1) assert.Nil(t, event1.Err) assert.NotNil(t, event1.Output) assert.NotNil(t, event1.Output.MessageOutput) assert.Equal(t, schema.Assistant, event1.Output.MessageOutput.Role) // Second event should be the tool output event2, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event2) assert.Nil(t, event2.Err) assert.NotNil(t, event2.Output) assert.NotNil(t, event2.Output.MessageOutput) assert.Equal(t, schema.Tool, event2.Output.MessageOutput.Role) // Third event should be the final model output event3, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event3) assert.Nil(t, event3.Err) assert.NotNil(t, event3.Output) assert.NotNil(t, event3.Output.MessageOutput) assert.Equal(t, schema.Assistant, event3.Output.MessageOutput.Role) // No more events _, ok = iterator.Next() assert.False(t, ok) }) } // TestExitTool tests the Exit tool functionality func TestExitTool(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock chat model cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model // First call: model generates a message with Exit tool call cm.EXPECT().Generate(gomock.Any(), 
gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("I'll exit with a final result", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: "exit", Arguments: `{"final_result": "This is the final result"}`}, }, }), nil). Times(1) // Model should implement WithTools cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create an agent with the Exit tool agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent with Exit tool", Instruction: "You are a helpful assistant.", Model: cm, Exit: &ExitTool{}, }) assert.NoError(t, err) assert.NotNil(t, agent) // Run the agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Please exit with a final result"), }, } iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // First event: model output with tool call event1, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event1) assert.Nil(t, event1.Err) assert.NotNil(t, event1.Output) assert.NotNil(t, event1.Output.MessageOutput) assert.Equal(t, schema.Assistant, event1.Output.MessageOutput.Role) // Second event: tool output (Exit) event2, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event2) assert.Nil(t, event2.Err) assert.NotNil(t, event2.Output) assert.NotNil(t, event2.Output.MessageOutput) assert.Equal(t, schema.Tool, event2.Output.MessageOutput.Role) // Verify the action is Exit assert.NotNil(t, event2.Action) assert.True(t, event2.Action.Exit) // Verify the final result assert.Equal(t, "This is the final result", event2.Output.MessageOutput.Message.Content) // No more events _, ok = iterator.Next() assert.False(t, ok) } func TestParallelReturnDirectlyToolCall(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock chat model cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model // First call: model generates a message with 
Exit tool call cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("I'll exit with a final result", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{Name: "tool1"}, }, { ID: "tool-call-2", Function: schema.FunctionCall{Name: "tool2"}, }, { ID: "tool-call-3", Function: schema.FunctionCall{Name: "tool3"}, }, }), nil). Times(1) // Model should implement WithTools cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create an agent with the Exit tool agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent with Exit tool", Instruction: "You are a helpful assistant.", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{ &myTool{name: "tool1", desc: "tool1", waitTime: time.Millisecond}, &myTool{name: "tool2", desc: "tool2", waitTime: 10 * time.Millisecond}, &myTool{name: "tool3", desc: "tool3", waitTime: 100 * time.Millisecond}, }, }, ReturnDirectly: map[string]bool{ "tool1": true, }, }, }) assert.NoError(t, err) assert.NotNil(t, agent) r := NewRunner(ctx, RunnerConfig{ Agent: agent, }) iter := r.Query(ctx, "") times := 0 for { e, ok := iter.Next() if !ok { assert.Equal(t, 4, times) break } if times == 3 { assert.Equal(t, "tool1", e.Output.MessageOutput.Message.ToolName) } times++ } } func TestConcurrentSameToolSendToolGenActionUsesToolCallID(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("tools", []schema.ToolCall{ {ID: "id1", Function: schema.FunctionCall{Name: "action_tool", Arguments: "A"}}, {ID: "id2", Function: schema.FunctionCall{Name: "action_tool", Arguments: "B"}}, }), nil). 
Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("done", nil), nil). Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Agent with action tool", Instruction: "", Model: cm, ToolsConfig: ToolsConfig{ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{actionTool{}}}}, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("go")}}) seen := map[string]bool{} for { e, ok := iter.Next() if !ok { break } if e.Output != nil && e.Output.MessageOutput != nil && e.Output.MessageOutput.Message != nil && e.Output.MessageOutput.Message.Role == schema.Tool { if e.Action != nil && e.Action.CustomizedAction != nil { if s, ok := e.Action.CustomizedAction.(string); ok { seen[s] = true } } } } assert.True(t, seen["A"]) assert.True(t, seen["B"]) } type myTool struct { name string desc string waitTime time.Duration } func (m *myTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: m.name, Desc: m.desc, }, nil } func (m *myTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { time.Sleep(m.waitTime) return "success", nil } type actionTool struct{} func (a actionTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "action_tool", Desc: "action tool"}, nil } func (a actionTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { _ = SendToolGenAction(ctx, "action_tool", &AgentAction{CustomizedAction: argumentsInJSON}) return "ok", nil } type streamActionTool struct{} func (s streamActionTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "action_tool_stream", Desc: "action stream tool"}, nil } func (s streamActionTool) StreamableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (*schema.StreamReader[string], error) 
{ _ = SendToolGenAction(ctx, "action_tool_stream", &AgentAction{CustomizedAction: argumentsInJSON}) sr, sw := schema.Pipe[string](1) go func() { defer sw.Close() _ = sw.Send("o", nil) _ = sw.Send("k", nil) }() return sr, nil } type legacyStreamActionTool struct{} func (s legacyStreamActionTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "legacy_action_tool_stream", Desc: "legacy action stream tool"}, nil } func (s legacyStreamActionTool) StreamableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (*schema.StreamReader[string], error) { _ = compose.ProcessState(ctx, func(ctx context.Context, st *State) error { st.setToolGenAction("legacy_action_tool_stream", &AgentAction{CustomizedAction: argumentsInJSON}) return nil }) sr, sw := schema.Pipe[string](1) go func() { defer sw.Close() _ = sw.Send("o", nil) _ = sw.Send("k", nil) }() return sr, nil } // TestChatModelAgentOutputKey tests the outputKey configuration and setOutputToSession function func TestChatModelAgentOutputKey(t *testing.T) { // Test outputKey configuration - stores output in session t.Run("OutputKeyStoresInSession", func(t *testing.T) { for i := 0; i < 1000; i++ { } ctx := context.Background() // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("Hello, I am an AI assistant.", nil), nil). 
Times(1) // Create a ChatModelAgent with outputKey configured agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for unit testing", Instruction: "You are a helpful assistant.", Model: cm, OutputKey: "agent_output", // This should store output in session }) assert.NoError(t, err) assert.NotNil(t, agent) // Initialize a run context to enable session storage input := &AgentInput{ Messages: []Message{ schema.UserMessage("Hello, who are you?"), }, } ctx, runCtx := initRunCtx(ctx, "TestAgent", input) assert.NotNil(t, runCtx) assert.NotNil(t, runCtx.Session) // Run the agent iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) // Verify the message content msg := event.Output.MessageOutput.Message assert.Equal(t, "Hello, I am an AI assistant.", msg.Content) // Verify that the output was stored in the session time.AfterFunc(100*time.Millisecond, func() { sessionValues := GetSessionValues(ctx) assert.Contains(t, sessionValues, "agent_output") assert.Equal(t, "Hello, I am an AI assistant.", sessionValues["agent_output"]) }) // No more events _, ok = iterator.Next() assert.False(t, ok) }) // Test outputKey configuration with streaming - stores concatenated output in session t.Run("OutputKeyWithStreamingStoresInSession", func(t *testing.T) { ctx := context.Background() // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Create a stream reader for the mock response sr := schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("Hello", nil), schema.AssistantMessage(", I am", nil), schema.AssistantMessage(" an AI assistant.", nil), }) // Set up expectations for the mock model cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(sr, nil). 
Times(1) // Create a ChatModelAgent with outputKey configured agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent for unit testing", Instruction: "You are a helpful assistant.", Model: cm, OutputKey: "agent_output", // This should store concatenated stream in session }) assert.NoError(t, err) assert.NotNil(t, agent) // Initialize a run context to enable session storage input := &AgentInput{ Messages: []Message{schema.UserMessage("Hello, who are you?")}, EnableStreaming: true, } ctx, runCtx := initRunCtx(ctx, "TestAgent", input) assert.NotNil(t, runCtx) assert.NotNil(t, runCtx.Session) // Run the agent iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) assert.True(t, event.Output.MessageOutput.IsStreaming) time.AfterFunc(100*time.Millisecond, func() { // Verify that the concatenated output was stored in the session sessionValues := GetSessionValues(ctx) assert.Contains(t, sessionValues, "agent_output") assert.Equal(t, "Hello, I am an AI assistant.", sessionValues["agent_output"]) }) // No more events _, ok = iterator.Next() assert.False(t, ok) }) // Test setOutputToSession function directly - regular message t.Run("SetOutputToSessionRegularMessage", func(t *testing.T) { ctx := context.Background() // Initialize a run context to enable session storage input := &AgentInput{ Messages: []Message{schema.UserMessage("test")}, } ctx, runCtx := initRunCtx(ctx, "TestAgent", input) assert.NotNil(t, runCtx) assert.NotNil(t, runCtx.Session) // Test with a regular message msg := schema.AssistantMessage("Test response", nil) err := setOutputToSession(ctx, msg, nil, "test_output") assert.NoError(t, err) // Verify the message content was stored sessionValues := GetSessionValues(ctx) assert.Contains(t, sessionValues, "test_output") assert.Equal(t, "Test response", 
sessionValues["test_output"]) }) // Test setOutputToSession function directly - streaming message t.Run("SetOutputToSessionStreamingMessage", func(t *testing.T) { ctx := context.Background() // Initialize a run context to enable session storage input := &AgentInput{ Messages: []Message{schema.UserMessage("test")}, } ctx, runCtx := initRunCtx(ctx, "TestAgent", input) assert.NotNil(t, runCtx) assert.NotNil(t, runCtx.Session) // Test with a streaming message sr := schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("Stream", nil), schema.AssistantMessage(" response", nil), schema.AssistantMessage(" content", nil), }) err := setOutputToSession(ctx, nil, sr, "test_output") assert.NoError(t, err) // Verify the concatenated stream content was stored sessionValues := GetSessionValues(ctx) assert.Contains(t, sessionValues, "test_output") assert.Equal(t, "Stream response content", sessionValues["test_output"]) }) // Test setOutputToSession function directly - error case t.Run("SetOutputToSessionErrorCase", func(t *testing.T) { ctx := context.Background() // Initialize a run context to enable session storage input := &AgentInput{ Messages: []Message{schema.UserMessage("test")}, } ctx, runCtx := initRunCtx(ctx, "TestAgent", input) assert.NotNil(t, runCtx) assert.NotNil(t, runCtx.Session) // Test with an invalid stream (simulate error) // Create a stream that will fail when concatenated sr := schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("test", nil), }) // Close the stream to simulate an error condition sr.Close() // This should return an error because the stream is closed err := setOutputToSession(ctx, nil, sr, "test_output") // Note: The actual behavior may vary depending on the stream implementation // Some streams may not error when closed, so we'll accept either outcome if err != nil { // If there's an error, verify nothing was stored sessionValues := GetSessionValues(ctx) assert.NotContains(t, sessionValues, "test_output") } 
else { // If no error, verify the content was stored sessionValues := GetSessionValues(ctx) assert.Contains(t, sessionValues, "test_output") assert.Equal(t, "test", sessionValues["test_output"]) } }) // Test outputKey with React workflow (tools enabled) t.Run("OutputKeyWithReactWorkflow", func(t *testing.T) { ctx := context.Background() // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Create a simple tool for testing fakeTool := &fakeToolForTest{ tarCount: 1, } // Set up expectations for the mock model - it will generate a final response cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("Final response from React workflow", nil), nil). Times(1) // Model should implement WithTools cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create a ChatModelAgent with outputKey and tools configured agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent with tools", Instruction: "You are a helpful assistant.", Model: cm, OutputKey: "agent_output", ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, }, }) assert.NoError(t, err) assert.NotNil(t, agent) // Initialize a run context to enable session storage input := &AgentInput{ Messages: []Message{schema.UserMessage("Use the tool")}, } ctx, runCtx := initRunCtx(ctx, "TestAgent", input) assert.NotNil(t, runCtx) assert.NotNil(t, runCtx.Session) // Run the agent iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.Nil(t, event.Err) // Verify the message content msg := event.Output.MessageOutput.Message assert.Equal(t, "Final response from React workflow", msg.Content) // Verify that the output was stored in the session time.AfterFunc(time.Millisecond*10, func() { sessionValues := 
GetSessionValues(ctx) assert.Contains(t, sessionValues, "agent_output") assert.Equal(t, "Final response from React workflow", sessionValues["agent_output"]) }) // No more events _, ok = iterator.Next() assert.False(t, ok) }) } func TestConcurrentSameStreamToolSendToolGenActionUsesToolCallID(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("tools", []schema.ToolCall{ {ID: "sid1", Function: schema.FunctionCall{Name: "action_tool_stream", Arguments: "SA"}}, {ID: "sid2", Function: schema.FunctionCall{Name: "action_tool_stream", Arguments: "SB"}}, }), nil). Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("done", nil), nil). Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Agent with stream action tool", Instruction: "", Model: cm, ToolsConfig: ToolsConfig{ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{streamActionTool{}}}}, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("go")}}) seen := map[string]bool{} for { e, ok := iter.Next() if !ok { break } if e.Output != nil && e.Output.MessageOutput != nil { if e.Output.MessageOutput.IsStreaming { if e.Action != nil && e.Action.CustomizedAction != nil { if s, ok := e.Action.CustomizedAction.(string); ok { seen[s] = true } } } } } assert.True(t, seen["SA"]) assert.True(t, seen["SB"]) } func TestStreamToolLegacyNameKeyFallback(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() cm := mockModel.NewMockToolCallingChatModel(ctrl) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.AssistantMessage("tools", []schema.ToolCall{ {ID: "lsid1", Function: schema.FunctionCall{Name: "legacy_action_tool_stream", Arguments: "LA"}}, }), nil). Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("done", nil), nil). Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Agent with legacy stream action tool", Instruction: "", Model: cm, ToolsConfig: ToolsConfig{ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{legacyStreamActionTool{}}}}, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("go")}}) found := false for { e, ok := iter.Next() if !ok { break } if e.Output != nil && e.Output.MessageOutput != nil && e.Output.MessageOutput.IsStreaming { if e.Action != nil && e.Action.CustomizedAction != nil { if s, ok := e.Action.CustomizedAction.(string); ok { found = (s == "LA") } } } } assert.True(t, found) } func TestChatModelAgent_ToolResultMiddleware_EmitsFinalResult(t *testing.T) { originalResult := "original_result" modifiedResult := "modified_by_middleware" resultModifyingMiddleware := compose.ToolMiddleware{ Invokable: func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { output, err := next(ctx, input) if err != nil { return nil, err } output.Result = modifiedResult return output, nil } }, Streamable: func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { output, err := next(ctx, input) if err != nil { return nil, err } output.Result = schema.StreamReaderFromArray([]string{modifiedResult}) return output, nil } }, } t.Run("Invoke", func(t *testing.T) { ctx := context.Background() testTool := &simpleToolForMiddlewareTest{name: "test_tool", result: originalResult} 
ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) info, err := testTool.Info(ctx) assert.NoError(t, err) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"input": "test"}`, }, }, }), nil). Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("final response", nil), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "test_agent", Description: "test agent with middleware", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{testTool}, ToolCallMiddlewares: []compose.ToolMiddleware{resultModifyingMiddleware}, }, }, }) assert.NoError(t, err) r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false, CheckPointStore: newBridgeStore()}) it := r.Run(ctx, []Message{schema.UserMessage("call the tool")}) var toolResultEvents []*AgentEvent for { ev, ok := it.Next() if !ok { break } if ev.Output != nil && ev.Output.MessageOutput != nil && ev.Output.MessageOutput.Message != nil && ev.Output.MessageOutput.Message.Role == schema.Tool { toolResultEvents = append(toolResultEvents, ev) } } assert.NotEmpty(t, toolResultEvents, "should have at least one tool result event") for _, ev := range toolResultEvents { assert.Equal(t, modifiedResult, ev.Output.MessageOutput.Message.Content, "tool result event should contain the middleware-modified result, not the original") assert.NotEqual(t, originalResult, ev.Output.MessageOutput.Message.Content, "tool result event should NOT contain the original result") } }) t.Run("Stream", func(t *testing.T) { ctx := context.Background() testTool := &simpleToolForMiddlewareTest{name: "test_tool_stream", result: originalResult} ctrl := gomock.NewController(t) cm := 
mockModel.NewMockToolCallingChatModel(ctrl) info, err := testTool.Info(ctx) assert.NoError(t, err) cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"input": "test"}`, }, }, }), }), nil). Times(1) cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("final response", nil), }), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "test_agent", Description: "test agent with middleware", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{testTool}, ToolCallMiddlewares: []compose.ToolMiddleware{resultModifyingMiddleware}, }, }, }) assert.NoError(t, err) r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true, CheckPointStore: newBridgeStore()}) it := r.Run(ctx, []Message{schema.UserMessage("call the tool")}) var toolResultContents []string for { ev, ok := it.Next() if !ok { break } if ev.Output != nil && ev.Output.MessageOutput != nil { if ev.Output.MessageOutput.Message != nil && ev.Output.MessageOutput.Message.Role == schema.Tool { toolResultContents = append(toolResultContents, ev.Output.MessageOutput.Message.Content) } if ev.Output.MessageOutput.IsStreaming && ev.Output.MessageOutput.MessageStream != nil && ev.Output.MessageOutput.Role == schema.Tool { var msgs []*schema.Message for { msg, err := ev.Output.MessageOutput.MessageStream.Recv() if err != nil { break } msgs = append(msgs, msg) } if len(msgs) > 0 { concated, err := schema.ConcatMessages(msgs) if err == nil { toolResultContents = append(toolResultContents, concated.Content) } } } } } assert.NotEmpty(t, toolResultContents, "should have at least one tool result event") for _, 
content := range toolResultContents { assert.Equal(t, modifiedResult, content, "tool result event should contain the middleware-modified result, not the original") assert.NotEqual(t, originalResult, content, "tool result event should NOT contain the original result") } }) } type simpleToolForMiddlewareTest struct { name string result string } func (s *simpleToolForMiddlewareTest) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: s.name, Desc: "simple tool", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "input": { Desc: "input", Required: true, Type: schema.String, }, }), }, nil } func (s *simpleToolForMiddlewareTest) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { return s.result, nil } func (s *simpleToolForMiddlewareTest) StreamableRun(_ context.Context, _ string, _ ...tool.Option) (*schema.StreamReader[string], error) { return schema.StreamReaderFromArray([]string{s.result}), nil } func TestGetComposeOptions(t *testing.T) { t.Run("WithChatModelOptions", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) var capturedTemperature float32 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { options := model.GetCommonOptions(&model.Options{}, opts...) 
if options.Temperature != nil { capturedTemperature = *options.Temperature } return schema.AssistantMessage("response", nil), nil }).Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: cm, }) assert.NoError(t, err) temp := float32(0.7) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}, WithChatModelOptions([]model.Option{model.WithTemperature(temp)})) for { _, ok := iter.Next() if !ok { break } } assert.Equal(t, temp, capturedTemperature, "Temperature should be passed through WithChatModelOptions") }) t.Run("WithToolOptions", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) var toolOptionsCaptured bool testTool := &toolOptionCapturingTool{ name: "test_tool", onRun: func(opts []tool.Option) { if len(opts) > 0 { toolOptionsCaptured = true } }, } info, _ := testTool.Info(ctx) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("Using tool", []schema.ToolCall{ {ID: "call1", Function: schema.FunctionCall{Name: info.Name, Arguments: "{}"}}, }), nil).Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.AssistantMessage("done", nil), nil).Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{testTool}, }, }, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}, WithToolOptions([]tool.Option{testToolOption("test_value")})) for { _, ok := iter.Next() if !ok { break } } assert.True(t, toolOptionsCaptured, "Tool options should be passed through WithToolOptions") }) } type toolOptionCapturingTool struct { name string onRun func(opts []tool.Option) } func (t *toolOptionCapturingTool) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: t.name, Desc: t.name + " description"}, nil } func (t *toolOptionCapturingTool) InvokableRun(_ context.Context, _ string, opts ...tool.Option) (string, error) { if t.onRun != nil { t.onRun(opts) } return t.name + " result", nil } type testToolOptions struct { value string } func testToolOption(value string) tool.Option { return tool.WrapImplSpecificOptFn(func(o *testToolOptions) { o.value = value }) } type errorTool struct { infoErr error } func (e *errorTool) Info(_ context.Context) (*schema.ToolInfo, error) { return nil, e.infoErr } func (e *errorTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { return "", nil } func TestChatModelAgent_PrepareExecContextError(t *testing.T) { t.Run("Run_WithToolInfoError_ReturnsError", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) expectedErr := errors.New("tool info error") errTool := &errorTool{infoErr: expectedErr} agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: 
[]tool.BaseTool{errTool}, }, }, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) event, ok := iter.Next() assert.True(t, ok) assert.NotNil(t, event.Err) assert.Contains(t, event.Err.Error(), "tool info error") _, ok = iter.Next() assert.False(t, ok) }) t.Run("Resume_WithToolInfoError_ReturnsError", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) expectedErr := errors.New("tool info error for resume") errTool := &errorTool{infoErr: expectedErr} agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{errTool}, }, }, }) assert.NoError(t, err) iter := agent.Resume(ctx, &ResumeInfo{ InterruptState: []byte("dummy"), EnableStreaming: false, }) event, ok := iter.Next() assert.True(t, ok) assert.NotNil(t, event.Err) assert.Contains(t, event.Err.Error(), "tool info error for resume") _, ok = iter.Next() assert.False(t, ok) }) } func TestPreprocessComposeCheckpoint_MigrateErrorIsReturned(t *testing.T) { in := []byte("prefix\u0015" + stateGobNameV080 + "suffix") _, err := preprocessComposeCheckpoint(in) assert.Error(t, err) } ================================================ FILE: adk/config.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 */

package adk

import "github.com/cloudwego/eino/adk/internal"

// Language represents the language setting for the ADK built-in prompts.
// It is a type alias (not a new named type), so values are interchangeable
// with internal.Language.
type Language = internal.Language

const (
	// LanguageEnglish represents English language.
	LanguageEnglish Language = internal.LanguageEnglish
	// LanguageChinese represents Chinese language.
	LanguageChinese Language = internal.LanguageChinese
)

// SetLanguage sets the language for the ADK built-in prompts.
// The default language is English if not explicitly set.
//
// It delegates to internal.SetLanguage and returns whatever error that call
// reports. NOTE(review): the conditions under which an error is returned
// (e.g. an unsupported value) are not visible here; confirm in internal.
func SetLanguage(lang Language) error {
	return internal.SetLanguage(lang)
}

================================================
FILE: adk/deterministic_transfer.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package adk

import (
	"context"
	"errors"
	"runtime/debug"
	"sync"

	"github.com/cloudwego/eino/components"
	"github.com/cloudwego/eino/internal/safe"
	"github.com/cloudwego/eino/schema"
)

// init registers *deterministicTransferState with schema.RegisterName under a
// fixed name. NOTE(review): this is presumably so the state can be
// (de)serialized, e.g. inside checkpoints. If so, the registered string must
// stay stable once persisted data exists. Confirm against schema.RegisterName.
func init() {
	schema.RegisterName[*deterministicTransferState]("_eino_adk_deterministic_transfer_state")
}

// deterministicTransferState holds the state registered in init.
// EventList is a list of wrapped agent events. NOTE(review): how and when it
// is populated is not visible in this chunk; see its users later in the file.
type deterministicTransferState struct {
	EventList []*agentEventWrapper
}

// AgentWithDeterministicTransferTo wraps an agent to transfer to given agents deterministically.
func AgentWithDeterministicTransferTo(_ context.Context, config *DeterministicTransferConfig) Agent { if ra, ok := config.Agent.(ResumableAgent); ok { return &resumableAgentWithDeterministicTransferTo{ agent: ra, toAgentNames: config.ToAgentNames, } } return &agentWithDeterministicTransferTo{ agent: config.Agent, toAgentNames: config.ToAgentNames, } } type agentWithDeterministicTransferTo struct { agent Agent toAgentNames []string } func (a *agentWithDeterministicTransferTo) Description(ctx context.Context) string { return a.agent.Description(ctx) } func (a *agentWithDeterministicTransferTo) Name(ctx context.Context) string { return a.agent.Name(ctx) } func (a *agentWithDeterministicTransferTo) GetType() string { if typer, ok := a.agent.(components.Typer); ok { return typer.GetType() } return "DeterministicTransfer" } func (a *agentWithDeterministicTransferTo) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { if fa, ok := a.agent.(*flowAgent); ok { return runFlowAgentWithIsolatedSession(ctx, fa, input, a.toAgentNames, options...) } aIter := a.agent.Run(ctx, input, options...) 
iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go forwardEventsAndAppendTransfer(aIter, generator, a.toAgentNames) return iterator } type resumableAgentWithDeterministicTransferTo struct { agent ResumableAgent toAgentNames []string } func (a *resumableAgentWithDeterministicTransferTo) Description(ctx context.Context) string { return a.agent.Description(ctx) } func (a *resumableAgentWithDeterministicTransferTo) Name(ctx context.Context) string { return a.agent.Name(ctx) } func (a *resumableAgentWithDeterministicTransferTo) GetType() string { if typer, ok := a.agent.(components.Typer); ok { return typer.GetType() } return "DeterministicTransfer" } func (a *resumableAgentWithDeterministicTransferTo) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { if fa, ok := a.agent.(*flowAgent); ok { return runFlowAgentWithIsolatedSession(ctx, fa, input, a.toAgentNames, options...) } aIter := a.agent.Run(ctx, input, options...) iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go forwardEventsAndAppendTransfer(aIter, generator, a.toAgentNames) return iterator } func (a *resumableAgentWithDeterministicTransferTo) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { if fa, ok := a.agent.(*flowAgent); ok { return resumeFlowAgentWithIsolatedSession(ctx, fa, info, a.toAgentNames, opts...) } aIter := a.agent.Resume(ctx, info, opts...) 
iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go forwardEventsAndAppendTransfer(aIter, generator, a.toAgentNames) return iterator } func forwardEventsAndAppendTransfer(iter *AsyncIterator[*AgentEvent], generator *AsyncGenerator[*AgentEvent], toAgentNames []string) { defer func() { if panicErr := recover(); panicErr != nil { generator.Send(&AgentEvent{Err: safe.NewPanicErr(panicErr, debug.Stack())}) } generator.Close() }() var lastEvent *AgentEvent for { event, ok := iter.Next() if !ok { break } generator.Send(event) lastEvent = event } if lastEvent != nil && lastEvent.Action != nil && (lastEvent.Action.Interrupted != nil || lastEvent.Action.Exit) { return } sendTransferEvents(generator, toAgentNames) } func runFlowAgentWithIsolatedSession(ctx context.Context, fa *flowAgent, input *AgentInput, toAgentNames []string, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { parentSession := getSession(ctx) parentRunCtx := getRunCtx(ctx) isolatedSession := &runSession{ Values: parentSession.Values, valuesMtx: parentSession.valuesMtx, } if isolatedSession.valuesMtx == nil { isolatedSession.valuesMtx = &sync.Mutex{} } if isolatedSession.Values == nil { isolatedSession.Values = make(map[string]any) } ctx = setRunCtx(ctx, &runContext{ Session: isolatedSession, RootInput: parentRunCtx.RootInput, RunPath: parentRunCtx.RunPath, }) iter := fa.Run(ctx, input, options...) 
iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go handleFlowAgentEvents(ctx, iter, generator, isolatedSession, parentSession, toAgentNames) return iterator } func resumeFlowAgentWithIsolatedSession(ctx context.Context, fa *flowAgent, info *ResumeInfo, toAgentNames []string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { state, ok := info.InterruptState.(*deterministicTransferState) if !ok || state == nil { return genErrorIter(errors.New("invalid interrupt state for flowAgent resume in deterministic transfer")) } parentSession := getSession(ctx) parentRunCtx := getRunCtx(ctx) isolatedSession := &runSession{ Values: parentSession.Values, valuesMtx: parentSession.valuesMtx, Events: state.EventList, } if isolatedSession.valuesMtx == nil { isolatedSession.valuesMtx = &sync.Mutex{} } if isolatedSession.Values == nil { isolatedSession.Values = make(map[string]any) } ctx = setRunCtx(ctx, &runContext{ Session: isolatedSession, RootInput: parentRunCtx.RootInput, RunPath: parentRunCtx.RunPath, }) iter := fa.Resume(ctx, info, opts...) 
iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go handleFlowAgentEvents(ctx, iter, generator, isolatedSession, parentSession, toAgentNames) return iterator } func handleFlowAgentEvents(ctx context.Context, iter *AsyncIterator[*AgentEvent], generator *AsyncGenerator[*AgentEvent], isolatedSession, parentSession *runSession, toAgentNames []string) { defer func() { if panicErr := recover(); panicErr != nil { generator.Send(&AgentEvent{Err: safe.NewPanicErr(panicErr, debug.Stack())}) } generator.Close() }() var lastEvent *AgentEvent for { event, ok := iter.Next() if !ok { break } if parentSession != nil && (event.Action == nil || event.Action.Interrupted == nil) { copied := copyAgentEvent(event) setAutomaticClose(copied) setAutomaticClose(event) parentSession.addEvent(copied) } if event.Action != nil && event.Action.internalInterrupted != nil { lastEvent = event continue } generator.Send(event) lastEvent = event } if lastEvent != nil && lastEvent.Action != nil { if lastEvent.Action.internalInterrupted != nil { events := isolatedSession.getEvents() state := &deterministicTransferState{EventList: events} compositeEvent := CompositeInterrupt(ctx, "deterministic transfer wrapper interrupted", state, lastEvent.Action.internalInterrupted) generator.Send(compositeEvent) return } if lastEvent.Action.Exit { return } } sendTransferEvents(generator, toAgentNames) } func sendTransferEvents(generator *AsyncGenerator[*AgentEvent], toAgentNames []string) { for _, toAgentName := range toAgentNames { aMsg, tMsg := GenTransferMessages(context.Background(), toAgentName) aEvent := EventFromMessage(aMsg, nil, schema.Assistant, "") generator.Send(aEvent) tEvent := EventFromMessage(tMsg, nil, schema.Tool, tMsg.ToolName) tEvent.Action = &AgentAction{ TransferToAgent: &TransferToAgentAction{ DestAgentName: toAgentName, }, } generator.Send(tEvent) } } ================================================ FILE: adk/deterministic_transfer_test.go 
================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) type dtTestStore struct { data map[string][]byte } func newDTTestStore() *dtTestStore { return &dtTestStore{data: make(map[string][]byte)} } func (s *dtTestStore) Set(_ context.Context, key string, value []byte) error { s.data[key] = value return nil } func (s *dtTestStore) Get(_ context.Context, key string) ([]byte, bool, error) { v, ok := s.data[key] return v, ok, nil } type dtTestAgent struct { name string runFn func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] resumeFn func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] } func (a *dtTestAgent) Name(_ context.Context) string { return a.name } func (a *dtTestAgent) Description(_ context.Context) string { return a.name + " description" } func (a *dtTestAgent) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { return a.runFn(ctx, input, options...) } func (a *dtTestAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { if a.resumeFn != nil { return a.resumeFn(ctx, info, opts...) } return a.runFn(ctx, &AgentInput{}, opts...) 
} func TestDeterministicTransferFlowAgentInterruptResume(t *testing.T) { ctx := context.Background() store := newDTTestStore() interruptData := "interrupt_data" var runCount int innerAgent := &dtTestAgent{ name: "inner", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { runCount++ iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() gen.Send(EventFromMessage(schema.AssistantMessage("before interrupt", nil), nil, schema.Assistant, "")) intEvent := Interrupt(ctx, interruptData) intEvent.Action.Interrupted.Data = interruptData gen.Send(intEvent) }() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { runCount++ assert.True(t, info.WasInterrupted, "innerAgent resumeFn: should be interrupted") assert.True(t, info.IsResumeTarget, "innerAgent resumeFn: should be resume target") runCtx := getRunCtx(ctx) assert.NotNil(t, runCtx, "innerAgent resumeFn: runCtx should not be nil") assert.NotNil(t, runCtx.Session, "innerAgent resumeFn: runCtx.Session should not be nil") var agentEvents []*AgentEvent for _, ev := range runCtx.Session.Events { if ev.AgentEvent != nil { agentEvents = append(agentEvents, ev.AgentEvent) } } assert.Len(t, agentEvents, 1, "innerAgent resumeFn: should have exactly 1 agent event") if len(agentEvents) == 1 { ev := agentEvents[0] assert.Equal(t, "inner", ev.AgentName, "innerAgent resumeFn: event should be from inner agent") assert.Equal(t, "before interrupt", ev.Output.MessageOutput.Message.Content, "innerAgent resumeFn: event content should be 'before interrupt'") assert.Len(t, ev.RunPath, 2, "innerAgent resumeFn: RunPath should have 2 steps (outer agent, inner agent)") if len(ev.RunPath) == 2 { assert.Equal(t, "outer", ev.RunPath[0].agentName, "innerAgent resumeFn: RunPath[0] should be outer agent") assert.Equal(t, "inner", ev.RunPath[1].agentName, "innerAgent resumeFn: RunPath[1] should be inner 
agent") } } iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() gen.Send(EventFromMessage(schema.AssistantMessage("after resume", nil), nil, schema.Assistant, "")) }() return iter }, } innerFlowAgent := toFlowAgent(ctx, innerAgent) wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: innerFlowAgent, ToAgentNames: []string{"next_agent"}, }) outerAgent := &dtTestAgent{ name: "outer", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { return wrapped.Run(ctx, input, options...) }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.True(t, info.WasInterrupted, "outerAgent resumeFn: should be interrupted") runCtx := getRunCtx(ctx) assert.NotNil(t, runCtx, "outerAgent resumeFn: runCtx should not be nil") assert.NotNil(t, runCtx.Session, "outerAgent resumeFn: runCtx.Session should not be nil") var agentEvents []*AgentEvent for _, ev := range runCtx.Session.Events { if ev.AgentEvent != nil { agentEvents = append(agentEvents, ev.AgentEvent) } } assert.Len(t, agentEvents, 1, "outerAgent resumeFn: should have exactly 1 agent event") if len(agentEvents) == 1 { ev := agentEvents[0] assert.Equal(t, "inner", ev.AgentName, "outerAgent resumeFn: event should be from inner agent (preserved original)") assert.Equal(t, "before interrupt", ev.Output.MessageOutput.Message.Content, "outerAgent resumeFn: event content should be 'before interrupt'") assert.Len(t, ev.RunPath, 2, "outerAgent resumeFn: RunPath should have 2 steps") if len(ev.RunPath) == 2 { assert.Equal(t, "outer", ev.RunPath[0].agentName, "outerAgent resumeFn: RunPath[0] should be outer agent") assert.Equal(t, "inner", ev.RunPath[1].agentName, "outerAgent resumeFn: RunPath[1] should be inner agent") } } ra := wrapped.(ResumableAgent) return ra.Resume(ctx, info, opts...) 
}, } outerFlowAgent := toFlowAgent(ctx, outerAgent) runner := NewRunner(ctx, RunnerConfig{ Agent: outerFlowAgent, EnableStreaming: true, CheckPointStore: store, }) iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, WithCheckPointID("cp1")) var events []*AgentEvent var interrupted bool var interruptEvent *AgentEvent for { ev, ok := iter.Next() if !ok { break } events = append(events, ev) if ev.Action != nil && ev.Action.Interrupted != nil { interrupted = true interruptEvent = ev } } assert.Equal(t, 1, runCount, "run should have been called once") assert.True(t, interrupted, "should have interrupted") assert.Greater(t, len(events), 0, "should have events") if interruptEvent == nil { t.Fatal("should have interrupt event") } assert.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts, "should have interrupt contexts") _, exists, err := store.Get(ctx, "cp1") assert.NoError(t, err) assert.True(t, exists, "checkpoint should have been saved") var hasDeterministicTransferContext bool for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts { t.Logf("InterruptContext: ID=%s, Info=%v, IsRootCause=%v, Addr=%v", intCtx.ID, intCtx.Info, intCtx.IsRootCause, intCtx.Address) if intCtx.Info == "deterministic transfer wrapper interrupted" { hasDeterministicTransferContext = true } for parent := intCtx.Parent; parent != nil; parent = parent.Parent { t.Logf(" Parent: ID=%s, Info=%v, Addr=%v", parent.ID, parent.Info, parent.Address) if parent.Info == "deterministic transfer wrapper interrupted" { hasDeterministicTransferContext = true } } } assert.True(t, hasDeterministicTransferContext, "should have deterministic transfer interrupt context") var rootCauseID string for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts { if intCtx.IsRootCause { rootCauseID = intCtx.ID break } } assert.NotEmpty(t, rootCauseID, "should have root cause interrupt ID") resumeIter, err := runner.ResumeWithParams(ctx, "cp1", &ResumeParams{ Targets: 
map[string]any{rootCauseID: nil}, }) assert.NoError(t, err) var resumeEvents []*AgentEvent var resumeErr error var hasTransfer bool for { ev, ok := resumeIter.Next() if !ok { break } if ev.Err != nil { resumeErr = ev.Err } if ev.Action != nil && ev.Action.TransferToAgent != nil { hasTransfer = true } resumeEvents = append(resumeEvents, ev) } assert.Equal(t, 2, runCount, "inner agent should be called twice (once for initial, once for resume)") assert.NotEmpty(t, resumeEvents, "should have resume events") assert.True(t, hasTransfer, "should have transfer action after resume") assert.Error(t, resumeErr, "transfer should fail because next_agent doesn't exist") assert.Contains(t, resumeErr.Error(), "next_agent", "error should mention the missing agent") } func TestDeterministicTransferRunPathPreserved(t *testing.T) { ctx := context.Background() store := newDTTestStore() var collectedRunPaths [][]RunStep innerAgent := &dtTestAgent{ name: "inner", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() ev := EventFromMessage(schema.AssistantMessage("from inner", nil), nil, schema.Assistant, "") gen.Send(ev) }() return iter }, } innerFlowAgent := toFlowAgent(ctx, innerAgent) wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: innerFlowAgent, ToAgentNames: []string{}, }) outerAgent := &dtTestAgent{ name: "outer", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { innerIter := wrapped.Run(ctx, input, options...) 
iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() for { ev, ok := innerIter.Next() if !ok { break } collectedRunPaths = append(collectedRunPaths, ev.RunPath) gen.Send(ev) } }() return iter }, } outerFlowAgent := toFlowAgent(ctx, outerAgent) runner := NewRunner(ctx, RunnerConfig{ Agent: outerFlowAgent, EnableStreaming: true, CheckPointStore: store, }) iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, WithCheckPointID("cp1")) for { _, ok := iter.Next() if !ok { break } } assert.NotEmpty(t, collectedRunPaths, "should have collected RunPaths") for _, rp := range collectedRunPaths { assert.Len(t, rp, 2, "RunPath should have 2 steps (outer agent, inner agent)") if len(rp) == 2 { assert.Equal(t, "outer", rp[0].agentName, "RunPath[0] should be outer agent") assert.Equal(t, "inner", rp[1].agentName, "RunPath[1] should be inner agent") } } } func TestDeterministicTransferExitSkipsTransfer(t *testing.T) { ctx := context.Background() store := newDTTestStore() innerAgent := &dtTestAgent{ name: "inner", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() ev := EventFromMessage(schema.AssistantMessage("inner exits", nil), nil, schema.Assistant, "") ev.Action = &AgentAction{Exit: true} gen.Send(ev) }() return iter }, } innerFlowAgent := toFlowAgent(ctx, innerAgent) wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: innerFlowAgent, ToAgentNames: []string{"next_agent"}, }) var outerSawExit bool var transferGenerated bool outerAgent := &dtTestAgent{ name: "outer", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { innerIter := wrapped.Run(ctx, input, options...) 
iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() for { ev, ok := innerIter.Next() if !ok { break } if ev.Action != nil && ev.Action.Exit { outerSawExit = true } if ev.Action != nil && ev.Action.TransferToAgent != nil { transferGenerated = true } gen.Send(ev) } }() return iter }, } outerFlowAgent := toFlowAgent(ctx, outerAgent) runner := NewRunner(ctx, RunnerConfig{ Agent: outerFlowAgent, EnableStreaming: true, CheckPointStore: store, }) iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, WithCheckPointID("cp1")) for { _, ok := iter.Next() if !ok { break } } assert.True(t, outerSawExit, "outer should see exit event from inner") assert.False(t, transferGenerated, "transfer should not be generated when inner exits") } type nonFlowTestAgent struct { name string runFn func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] resumeFn func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] } func (a *nonFlowTestAgent) Name(_ context.Context) string { return a.name } func (a *nonFlowTestAgent) Description(_ context.Context) string { return a.name + " description" } func (a *nonFlowTestAgent) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { return a.runFn(ctx, input, options...) } func (a *nonFlowTestAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { if a.resumeFn != nil { return a.resumeFn(ctx, info, opts...) } return a.runFn(ctx, &AgentInput{}, opts...) 
} type nonResumableTestAgent struct { name string runFn func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] } func (a *nonResumableTestAgent) Name(_ context.Context) string { return a.name } func (a *nonResumableTestAgent) Description(_ context.Context) string { return a.name + " description" } func (a *nonResumableTestAgent) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { return a.runFn(ctx, input, options...) } func TestDeterministicTransferNonFlowAgent_ExitSkipsTransfer(t *testing.T) { ctx := context.Background() agent := &nonFlowTestAgent{ name: "test_agent", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() ev := EventFromMessage(schema.AssistantMessage("exiting", nil), nil, schema.Assistant, "") ev.Action = &AgentAction{Exit: true} gen.Send(ev) }() return iter }, } wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: agent, ToAgentNames: []string{"next_agent"}, }) iter := wrapped.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) var events []*AgentEvent var sawExit bool var sawTransfer bool for { ev, ok := iter.Next() if !ok { break } events = append(events, ev) if ev.Action != nil && ev.Action.Exit { sawExit = true } if ev.Action != nil && ev.Action.TransferToAgent != nil { sawTransfer = true } } assert.True(t, sawExit, "should see exit event") assert.False(t, sawTransfer, "should NOT see transfer when exit is last event") assert.Len(t, events, 1, "should have exactly 1 event (exit)") } func TestDeterministicTransferNonFlowAgent_AppendsTransfer(t *testing.T) { ctx := context.Background() agent := &nonFlowTestAgent{ name: "test_agent", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := 
NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() ev := EventFromMessage(schema.AssistantMessage("normal output", nil), nil, schema.Assistant, "") gen.Send(ev) }() return iter }, } wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: agent, ToAgentNames: []string{"next_agent"}, }) iter := wrapped.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) var events []*AgentEvent var sawTransfer bool var transferTarget string for { ev, ok := iter.Next() if !ok { break } events = append(events, ev) if ev.Action != nil && ev.Action.TransferToAgent != nil { sawTransfer = true transferTarget = ev.Action.TransferToAgent.DestAgentName } } assert.True(t, sawTransfer, "should see transfer event after normal completion") assert.Equal(t, "next_agent", transferTarget, "transfer target should be next_agent") assert.Greater(t, len(events), 1, "should have more than 1 event (output + transfer messages)") } func TestDeterministicTransferNonFlowAgent_InterruptSkipsTransfer(t *testing.T) { ctx := context.Background() agent := &nonFlowTestAgent{ name: "test_agent", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() ev := &AgentEvent{ Action: &AgentAction{ Interrupted: &InterruptInfo{Data: "test interrupt"}, }, } gen.Send(ev) }() return iter }, } wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: agent, ToAgentNames: []string{"next_agent"}, }) iter := wrapped.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) var events []*AgentEvent var sawInterrupt bool var sawTransfer bool for { ev, ok := iter.Next() if !ok { break } events = append(events, ev) if ev.Action != nil && ev.Action.Interrupted != nil { sawInterrupt = true } if ev.Action != nil && ev.Action.TransferToAgent != nil { sawTransfer = true } } assert.True(t, sawInterrupt, 
"should see interrupt event") assert.False(t, sawTransfer, "should NOT see transfer when interrupted") assert.Len(t, events, 1, "should have exactly 1 event (interrupt)") } func TestDeterministicTransferNonFlowAgent_Resume(t *testing.T) { ctx := context.Background() var resumeCalled bool agent := &nonFlowTestAgent{ name: "test_agent", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() ev := EventFromMessage(schema.AssistantMessage("from run", nil), nil, schema.Assistant, "") gen.Send(ev) }() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { resumeCalled = true iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() ev := EventFromMessage(schema.AssistantMessage("from resume", nil), nil, schema.Assistant, "") gen.Send(ev) }() return iter }, } wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: agent, ToAgentNames: []string{"next_agent"}, }) ra, ok := wrapped.(ResumableAgent) assert.True(t, ok, "wrapped agent should be ResumableAgent") iter := ra.Resume(ctx, &ResumeInfo{WasInterrupted: true}) var events []*AgentEvent var sawTransfer bool for { ev, ok := iter.Next() if !ok { break } events = append(events, ev) if ev.Action != nil && ev.Action.TransferToAgent != nil { sawTransfer = true } } assert.True(t, resumeCalled, "resume should have been called on inner agent") assert.True(t, sawTransfer, "should see transfer event after resume completion") } func TestDeterministicTransferFlowAgent_ResumeWithInvalidState(t *testing.T) { ctx := context.Background() innerAgent := &dtTestAgent{ name: "inner", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() 
gen.Send(EventFromMessage(schema.AssistantMessage("test", nil), nil, schema.Assistant, "")) }() return iter }, } innerFlowAgent := toFlowAgent(ctx, innerAgent) wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: innerFlowAgent, ToAgentNames: []string{"next_agent"}, }) ra, ok := wrapped.(ResumableAgent) assert.True(t, ok, "wrapped flowAgent should be ResumableAgent") iter := ra.Resume(ctx, &ResumeInfo{ WasInterrupted: true, InterruptState: nil, }) var gotError bool var errorMsg string for { ev, ok := iter.Next() if !ok { break } if ev.Err != nil { gotError = true errorMsg = ev.Err.Error() } } assert.True(t, gotError, "should get error for invalid state") assert.Contains(t, errorMsg, "invalid interrupt state", "error should mention invalid state") } func TestDeterministicTransferNonResumableAgent(t *testing.T) { ctx := context.Background() agent := &nonResumableTestAgent{ name: "non_resumable", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer gen.Close() ev := EventFromMessage(schema.AssistantMessage("output", nil), nil, schema.Assistant, "") gen.Send(ev) }() return iter }, } wrapped := AgentWithDeterministicTransferTo(ctx, &DeterministicTransferConfig{ Agent: agent, ToAgentNames: []string{"next_agent"}, }) _, isResumable := wrapped.(ResumableAgent) assert.False(t, isResumable, "wrapped non-resumable agent should NOT be ResumableAgent") assert.Equal(t, "non_resumable", wrapped.Name(ctx), "Name should delegate to inner agent") assert.Equal(t, "non_resumable description", wrapped.Description(ctx), "Description should delegate to inner agent") iter := wrapped.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) var sawTransfer bool for { ev, ok := iter.Next() if !ok { break } if ev.Action != nil && ev.Action.TransferToAgent != nil { sawTransfer = true } } assert.True(t, sawTransfer, "should see 
transfer event") } ================================================ FILE: adk/filesystem/backend.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package filesystem provides file system operations. package filesystem import ( "context" "github.com/cloudwego/eino/schema" ) // FileInfo represents basic file metadata information. type FileInfo struct { // Path is the path of the file or directory, which can be a filename, relative path, or absolute path. Path string // IsDir indicates whether the entry is a directory. // true for directories, false for regular files. IsDir bool // Size is the file size in bytes. // For directories, this value may be 0 or platform-dependent. Size int64 // ModifiedAt is the last modification time in ISO 8601 format. // Example: "2025-01-15T10:30:00Z" ModifiedAt string } // GrepMatch represents a single pattern match result. type GrepMatch struct { Content string // Path is the file path where the match was found. Path string // Line is the 1-based line number of the match. Line int } // LsInfoRequest contains parameters for listing file information. type LsInfoRequest struct { // Path specifies the directory path to list. Path string } // ReadRequest contains parameters for reading file content. type ReadRequest struct { // FilePath is the path to the file to be read. FilePath string // Offset specifies the starting line number (1-based) for reading. 
// Line 1 is the first line of the file. // Use this when the file is too large to read at once. // Defaults to 1 (start from the first line). // Values < 1 will be treated as 1. Offset int // Limit specifies the maximum number of lines to read. // Use this when the file is too large to read at once. // Defaults to 2000 if not provided or non-positive (<= 0). Limit int } // GrepRequest contains parameters for searching file content. type GrepRequest struct { // ===== Search Parameters ===== // Pattern is the search pattern, supports full regular expression syntax. // Uses ripgrep syntax (not grep). Examples: // - "log.*Error" matches lines with "log" followed by "Error" // - "function\\s+\\w+" matches "function" followed by whitespace and word characters // - Literal braces need escaping: "interface\\{\\}" matches "interface{}" Pattern string // Path is an optional directory path to limit the search scope. Path string // ===== File Filtering ===== // Glob is an optional pattern to filter the files to be searched. // It filters by file path, not content. If empty, no files are filtered. // Supports standard glob wildcards: // - `*` matches any characters except path separators. // - `**` matches any directories recursively. // - `?` matches a single character. // - `[abc]` matches one character from the set. Glob string // FileType is the file type filter, e.g., "js", "py", "rust". // More efficient than Glob for standard file types. FileType string // ===== Search Options ===== // CaseInsensitive enables case insensitive search. CaseInsensitive bool // EnableMultiline enables multiline mode where patterns can span lines. // Default: false (patterns match within single lines only). EnableMultiline bool // ===== Context Display (Content mode only) ===== // AfterLines shows N lines after each match. // Only applicable when OutputMode is "content". // Values <= 0 are treated as unset. AfterLines int // BeforeLines shows N lines before each match. 
// Only applicable when OutputMode is "content". // Values <= 0 are treated as unset. BeforeLines int } // GlobInfoRequest contains parameters for glob pattern matching. type GlobInfoRequest struct { // Pattern is the glob expression used to match file paths. // It supports standard glob syntax: // - `*` matches any characters except path separators. // - `**` matches any directories recursively. // - `?` matches a single character. // - `[abc]` matches one character from the set. Pattern string // Path is the base directory from which to start the search. Path string } // WriteRequest contains parameters for writing file content. type WriteRequest struct { // FilePath is the path of the file to write. FilePath string // Content is the data to be written to the file. Content string } // EditRequest contains parameters for editing file content. type EditRequest struct { // FilePath is the path of the file to edit. FilePath string // OldString is the exact string to be replaced. It must be non-empty and will be matched literally, including whitespace. OldString string // NewString is the string that will replace OldString. // It must be different from OldString. // An empty string can be used to effectively delete OldString. NewString string // ReplaceAll controls the replacement behavior. // If true, all occurrences of OldString are replaced. // If false, the operation fails unless OldString appears exactly once in the file. ReplaceAll bool } type FileContent struct { Content string } // Backend is a pluggable, unified file backend protocol interface. // // All methods use struct-based parameters to allow future extensibility // without breaking backward compatibility. type Backend interface { // LsInfo lists file information under the given path. 
// // Returns: // - []FileInfo: List of matching file information // - error: Error if the operation fails LsInfo(ctx context.Context, req *LsInfoRequest) ([]FileInfo, error) // Read reads file content with support for line-based offset and limit. // // Returns: // - string: The file content read // - error: Error if file does not exist or read fails Read(ctx context.Context, req *ReadRequest) (*FileContent, error) // GrepRaw searches for content matching the specified pattern in files. // // Returns: // - []GrepMatch: List of all matching results // - error: Error if the search fails GrepRaw(ctx context.Context, req *GrepRequest) ([]GrepMatch, error) // GlobInfo returns file information matching the glob pattern. // // Returns: // - []FileInfo: List of matching file information // - error: Error if the pattern is invalid or operation fails GlobInfo(ctx context.Context, req *GlobInfoRequest) ([]FileInfo, error) // Write creates or updates file content. // // Returns: // - error: Error if the write operation fails Write(ctx context.Context, req *WriteRequest) error // Edit replaces string occurrences in a file. // // Returns: // - error: Error if file does not exist, OldString is empty, or OldString is not found Edit(ctx context.Context, req *EditRequest) error } // ExecuteRequest contains parameters for executing a command. type ExecuteRequest struct { Command string // The command to execute RunInBackendGround bool } // ExecuteResponse contains the response result of command execution. 
type ExecuteResponse struct {
	Output    string // Command output content
	ExitCode  *int   // Command exit code (a pointer; presumably nil when no exit code is available — TODO confirm with Shell implementations)
	Truncated bool   // Whether the output was truncated
}

// Shell executes the command described by an ExecuteRequest and returns the
// whole result as a single ExecuteResponse.
type Shell interface {
	Execute(ctx context.Context, input *ExecuteRequest) (result *ExecuteResponse, err error)
}

// StreamingShell executes the command described by an ExecuteRequest and
// delivers the result as a stream of ExecuteResponse values.
// NOTE(review): whether each streamed element carries incremental or
// cumulative output is not defined here — confirm with implementations.
type StreamingShell interface {
	ExecuteStreaming(ctx context.Context, input *ExecuteRequest) (result *schema.StreamReader[*ExecuteResponse], err error)
}

================================================
FILE: adk/filesystem/backend_inmemory.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package filesystem

import (
	"context"
	"fmt"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/bmatcuk/doublestar/v4"
)

// fileEntry is one stored file: its full content and the time it was last
// written. Write and Edit replace the entry and set modifiedAt to time.Now().
type fileEntry struct {
	content    string
	modifiedAt time.Time
}

// InMemoryBackend is an in-memory implementation of the Backend interface.
// It stores files in a map and is safe for concurrent use.
type InMemoryBackend struct {
	// mu guards files. Readers (LsInfo, Read, GrepRaw, GlobInfo) take the read
	// lock; Write and Edit take the write lock.
	mu sync.RWMutex
	// files is keyed by the path produced by normalizePath (leading "/",
	// cleaned), so lookups must normalize the request path first.
	files map[string]*fileEntry
}

// NewInMemoryBackend creates a new in-memory backend with no files.
func NewInMemoryBackend() *InMemoryBackend {
	return &InMemoryBackend{
		files: make(map[string]*fileEntry),
	}
}

// LsInfo lists file information under the given path.
func (b *InMemoryBackend) LsInfo(ctx context.Context, req *LsInfoRequest) ([]FileInfo, error) { b.mu.RLock() defer b.mu.RUnlock() // Normalize path path := normalizePath(req.Path) var result []FileInfo seen := make(map[string]bool) dirInfo := make(map[string]*FileInfo) for filePath, entry := range b.files { normalizedFilePath := normalizePath(filePath) // Check if file is under the given path if path == "/" || strings.HasPrefix(normalizedFilePath, path+"/") || normalizedFilePath == path { // For directory listing, we want to show immediate children relativePath := strings.TrimPrefix(normalizedFilePath, path) relativePath = strings.TrimPrefix(relativePath, "/") if relativePath == "" { // The path itself is a file if !seen[normalizedFilePath] { result = append(result, FileInfo{ Path: filepath.Base(normalizedFilePath), IsDir: false, Size: int64(len(entry.content)), ModifiedAt: entry.modifiedAt.Format(time.RFC3339Nano), }) seen[normalizedFilePath] = true } continue } // Get the first segment (immediate child) parts := strings.SplitN(relativePath, "/", 2) if len(parts) > 0 { childPath := path if path != "/" { childPath += "/" } childPath += parts[0] isDir := len(parts) > 1 if !seen[childPath] { if isDir { dirInfo[childPath] = &FileInfo{ Path: parts[0], IsDir: true, Size: 0, ModifiedAt: entry.modifiedAt.Format(time.RFC3339Nano), } } else { result = append(result, FileInfo{ Path: parts[0], IsDir: false, Size: int64(len(entry.content)), ModifiedAt: entry.modifiedAt.Format(time.RFC3339Nano), }) } seen[childPath] = true } else if isDir { if info, ok := dirInfo[childPath]; ok { if entry.modifiedAt.After(mustParseTime(info.ModifiedAt)) { info.ModifiedAt = entry.modifiedAt.Format(time.RFC3339Nano) } } } } } } for _, info := range dirInfo { result = append(result, *info) } return result, nil } func mustParseTime(s string) time.Time { t, _ := time.Parse(time.RFC3339Nano, s) return t } // Read reads file content with offset and limit. 
func (b *InMemoryBackend) Read(ctx context.Context, req *ReadRequest) (*FileContent, error) { b.mu.RLock() defer b.mu.RUnlock() filePath := normalizePath(req.FilePath) entry, exists := b.files[filePath] if !exists { return nil, fmt.Errorf("file not found: %s", filePath) } // Convert 1-based offset to 0-based index; values < 1 default to line 1 offset := req.Offset - 1 if offset < 0 { offset = 0 } limit := req.Limit if limit <= 0 { limit = 2000 } content := entry.content // Fast path: no offset, content fits within limit — return as-is if offset == 0 { lineCount := strings.Count(content, "\n") + 1 if lineCount <= limit { return &FileContent{Content: content}, nil } } // Skip `offset` lines by scanning for newlines directly start := 0 for i := 0; i < offset; i++ { idx := strings.IndexByte(content[start:], '\n') if idx == -1 { // offset exceeds total lines return &FileContent{}, nil } start += idx + 1 } // Find the end position after `limit` lines end := start for i := 0; i < limit; i++ { idx := strings.IndexByte(content[end:], '\n') if idx == -1 { // Reached the end of content return &FileContent{Content: content[start:]}, nil } end += idx + 1 } // Trim the trailing newline from the last included line return &FileContent{Content: content[start : end-1]}, nil } // GrepRaw returns matches for the given pattern. 
func (b *InMemoryBackend) GrepRaw(ctx context.Context, req *GrepRequest) ([]GrepMatch, error) { b.mu.RLock() defer b.mu.RUnlock() if req.Pattern == "" { return nil, fmt.Errorf("pattern cannot be empty") } re, err := b.compilePattern(req) if err != nil { return nil, err } searchPath := "/" if req.Path != "" { searchPath = normalizePath(req.Path) } filteredFiles, err := b.filterFiles(searchPath, req) if err != nil { return nil, err } if len(filteredFiles) == 0 { return []GrepMatch{}, nil } if len(filteredFiles) == 1 { collector := newGrepCollector() entry := b.files[filteredFiles[0]] collector.processFile(filteredFiles[0], entry.content, re, req) return collector.buildResults(b, req) } matches, err := b.grepFilesInParallel(filteredFiles, re, req) if err != nil { return nil, err } if req.BeforeLines > 0 || req.AfterLines > 0 { matches = b.applyContext(matches, req) } return matches, nil } func (b *InMemoryBackend) grepFilesInParallel(filteredFiles []string, re *regexp.Regexp, req *GrepRequest) ([]GrepMatch, error) { numWorkers := len(filteredFiles) if numWorkers > 10 { numWorkers = 10 } type fileTask struct { path string content string } tasks := make(chan fileTask, len(filteredFiles)) results := make(chan []GrepMatch, len(filteredFiles)) errChan := make(chan error, numWorkers) var wg sync.WaitGroup for i := 0; i < numWorkers; i++ { wg.Add(1) go func(workerID int) { defer wg.Done() defer func() { if r := recover(); r != nil { errChan <- fmt.Errorf("worker %d panic: %v", workerID, r) } }() collector := newGrepCollector() for task := range tasks { fileMatches := collector.findMatches(task.path, task.content, re, req) if len(fileMatches) > 0 { results <- fileMatches } } }(i) } for _, filePath := range filteredFiles { entry := b.files[filePath] tasks <- fileTask{ path: filePath, content: entry.content, } } close(tasks) go func() { wg.Wait() close(results) close(errChan) }() var allMatches []GrepMatch var errs []error for { select { case matches, ok := <-results: if !ok { 
results = nil } else { allMatches = append(allMatches, matches...) } case err, ok := <-errChan: if !ok { errChan = nil } else if err != nil { errs = append(errs, err) } } if results == nil && errChan == nil { break } } if len(errs) > 0 { return nil, fmt.Errorf("grep failed with %d error(s): %v", len(errs), errs[0]) } return allMatches, nil } func (b *InMemoryBackend) compilePattern(req *GrepRequest) (*regexp.Regexp, error) { pattern := req.Pattern if req.CaseInsensitive { pattern = "(?i)" + pattern } re, err := regexp.Compile(pattern) if err != nil { return nil, fmt.Errorf("invalid regex pattern: %w", err) } return re, nil } func (b *InMemoryBackend) filterFiles(searchPath string, req *GrepRequest) ([]string, error) { var candidateFiles []string for filePath := range b.files { normalizedFilePath := normalizePath(filePath) if searchPath != "/" && !strings.HasPrefix(normalizedFilePath, searchPath+"/") && normalizedFilePath != searchPath { continue } candidateFiles = append(candidateFiles, normalizedFilePath) } if req.Glob != "" { filtered, err := b.filterByGlob(candidateFiles, searchPath, req.Glob) if err != nil { return nil, err } candidateFiles = filtered } if req.FileType != "" { candidateFiles = b.filterByFileType(candidateFiles, req.FileType) } return candidateFiles, nil } func (b *InMemoryBackend) filterByGlob(files []string, searchPath string, globPattern string) ([]string, error) { var result []string for _, filePath := range files { var matchPath string if strings.Contains(globPattern, "/") || strings.Contains(globPattern, "**") { if searchPath == "/" { matchPath = strings.TrimPrefix(filePath, "/") } else { matchPath = strings.TrimPrefix(filePath, searchPath+"/") } } else { matchPath = filepath.Base(filePath) } matched, err := doublestar.Match(globPattern, matchPath) if err != nil { return nil, fmt.Errorf("invalid glob pattern: %w", err) } if matched { result = append(result, filePath) } } return result, nil } func (b *InMemoryBackend) filterByFileType(files 
[]string, fileType string) []string { var result []string for _, filePath := range files { ext := strings.TrimPrefix(filepath.Ext(filePath), ".") if matchFileType(ext, fileType) { result = append(result, filePath) } } return result } // matchFileType checks if the file extension matches the given file type. func matchFileType(ext, fileType string) bool { typeMap := map[string][]string{ "ada": {"adb", "ads"}, "agda": {"agda", "lagda"}, "aidl": {"aidl"}, "amake": {"bp", "mk"}, "asciidoc": {"adoc", "asc", "asciidoc"}, "asm": {"S", "asm", "s"}, "asp": {"ascx", "asp", "aspx"}, "ats": {"ats", "dats", "hats", "sats"}, "avro": {"avdl", "avpr", "avsc"}, "awk": {"awk"}, "bat": {"bat"}, "bazel": {"BUILD", "bazel", "bzl"}, "bitbake": {"bb", "bbappend", "bbclass", "conf", "inc"}, "c": {"c", "h", "H", "cats"}, "cabal": {"cabal"}, "cbor": {"cbor"}, "ceylon": {"ceylon"}, "clojure": {"clj", "cljc", "cljs", "cljx"}, "cmake": {"cmake"}, "coffeescript": {"coffee"}, "config": {"cfg", "conf", "config", "ini"}, "coq": {"v"}, "cpp": {"C", "cc", "cpp", "cxx", "c++", "h", "hh", "hpp", "hxx", "h++", "inl"}, "crystal": {"cr", "ecr"}, "cs": {"cs"}, "csharp": {"cs"}, "cshtml": {"cshtml"}, "css": {"css", "scss", "sass", "less"}, "csv": {"csv"}, "cuda": {"cu", "cuh"}, "cython": {"pxd", "pxi", "pyx"}, "d": {"d"}, "dart": {"dart"}, "devicetree": {"dts", "dtsi"}, "dhall": {"dhall"}, "diff": {"diff", "patch"}, "docker": {"dockerfile"}, "go": {"go"}, "groovy": {"gradle", "groovy"}, "haskell": {"c2hs", "cpphs", "hs", "hsc", "lhs"}, "html": {"ejs", "htm", "html"}, "java": {"java", "jsp", "jspx", "properties"}, "js": {"cjs", "js", "jsx", "mjs", "vue"}, "json": {"json", "sarif"}, "jsonl": {"jsonl"}, "julia": {"jl"}, "jupyter": {"ipynb", "jpynb"}, "kotlin": {"kt", "kts"}, "less": {"less"}, "lua": {"lua"}, "make": {"mak", "mk"}, "markdown": {"markdown", "md", "mdown", "mdwn", "mdx", "mkd", "mkdn"}, "md": {"markdown", "md", "mdown", "mdwn", "mdx", "mkd", "mkdn"}, "matlab": {"m"}, "ocaml": {"ml", "mli", 
"mll", "mly"}, "perl": {"PL", "perl", "pl", "plh", "plx", "pm", "t"}, "php": {"php", "php3", "php4", "php5", "php7", "php8", "pht", "phtml"}, "python": {"py", "pyi"}, "py": {"py", "pyi"}, "ruby": {"gemspec", "rb", "rbw"}, "rust": {"rs"}, "sass": {"sass", "scss"}, "scala": {"sbt", "scala"}, "sh": {"bash", "sh", "zsh"}, "sql": {"psql", "sql"}, "swift": {"swift"}, "toml": {"toml"}, "ts": {"cts", "mts", "ts", "tsx"}, "typescript": {"cts", "mts", "ts", "tsx"}, "txt": {"txt"}, "vue": {"vue"}, "xml": {"dtd", "xml", "xsd", "xsl", "xslt"}, "yaml": {"yaml", "yml"}, "zig": {"zig"}, } if exts, ok := typeMap[fileType]; ok { for _, e := range exts { if ext == e { return true } } } return ext == fileType } // applyContext adds context lines around matches. func (b *InMemoryBackend) applyContext(matches []GrepMatch, req *GrepRequest) []GrepMatch { if len(matches) == 0 { return matches } beforeLines := 0 afterLines := 0 if req.BeforeLines > 0 { beforeLines = req.BeforeLines } if req.AfterLines > 0 { afterLines = req.AfterLines } if beforeLines <= 0 && afterLines <= 0 { return matches } // Group matches by file path for efficient processing matchesByFile := make(map[string][]GrepMatch) fileOrder := make([]string, 0) seenFiles := make(map[string]bool) for _, match := range matches { if !seenFiles[match.Path] { fileOrder = append(fileOrder, match.Path) seenFiles[match.Path] = true } matchesByFile[match.Path] = append(matchesByFile[match.Path], match) } var result []GrepMatch // Process each file once for _, filePath := range fileOrder { fileMatches := matchesByFile[filePath] // Get file content once per file b.mu.RLock() entry, exists := b.files[filePath] b.mu.RUnlock() if !exists { // If file doesn't exist, keep original matches result = append(result, fileMatches...) 
continue } lines := strings.Split(entry.content, "\n") processedLines := make(map[int]bool) // Process all matches for this file for _, match := range fileMatches { startLine := match.Line - beforeLines if startLine < 1 { startLine = 1 } endLine := match.Line + afterLines if endLine > len(lines) { endLine = len(lines) } for lineNum := startLine; lineNum <= endLine; lineNum++ { if !processedLines[lineNum] { processedLines[lineNum] = true result = append(result, GrepMatch{ Path: filePath, Line: lineNum, Content: lines[lineNum-1], }) } } } } return result } // GlobInfo returns file info entries matching the glob pattern. func (b *InMemoryBackend) GlobInfo(ctx context.Context, req *GlobInfoRequest) ([]FileInfo, error) { b.mu.RLock() defer b.mu.RUnlock() basePath := normalizePath(req.Path) isAbsolutePattern := strings.HasPrefix(req.Pattern, "/") var result []FileInfo for filePath, entry := range b.files { normalizedFilePath := normalizePath(filePath) var matchPath string var resultPath string if isAbsolutePattern { matchPath = normalizedFilePath resultPath = normalizedFilePath } else { if basePath != "/" && !strings.HasPrefix(normalizedFilePath, basePath+"/") && normalizedFilePath != basePath { continue } if basePath == "/" { matchPath = strings.TrimPrefix(normalizedFilePath, "/") } else { matchPath = strings.TrimPrefix(normalizedFilePath, basePath+"/") } resultPath = matchPath } matched, err := doublestar.Match(req.Pattern, matchPath) if err != nil { return nil, fmt.Errorf("invalid glob pattern: %w", err) } if matched { result = append(result, FileInfo{ Path: resultPath, IsDir: false, Size: int64(len(entry.content)), ModifiedAt: entry.modifiedAt.Format(time.RFC3339Nano), }) } } return result, nil } // Write creates or overwrites file content. 
func (b *InMemoryBackend) Write(ctx context.Context, req *WriteRequest) error { b.mu.Lock() defer b.mu.Unlock() filePath := normalizePath(req.FilePath) b.files[filePath] = &fileEntry{ content: req.Content, modifiedAt: time.Now(), } return nil } // Edit replaces string occurrences in a file. func (b *InMemoryBackend) Edit(ctx context.Context, req *EditRequest) error { b.mu.Lock() defer b.mu.Unlock() filePath := normalizePath(req.FilePath) entry, exists := b.files[filePath] if !exists { return fmt.Errorf("file not found: %s", filePath) } if req.OldString == "" { return fmt.Errorf("oldString must be non-empty") } content := entry.content if !strings.Contains(content, req.OldString) { return fmt.Errorf("oldString not found in file: %s", filePath) } if !req.ReplaceAll { firstIndex := strings.Index(content, req.OldString) if firstIndex != -1 { // Check if there's another occurrence after the first one if strings.Contains(content[firstIndex+len(req.OldString):], req.OldString) { return fmt.Errorf("multiple occurrences of oldString found in file %s, but ReplaceAll is false", filePath) } } } var newContent string if req.ReplaceAll { newContent = strings.ReplaceAll(content, req.OldString, req.NewString) } else { newContent = strings.Replace(content, req.OldString, req.NewString, 1) } b.files[filePath] = &fileEntry{ content: newContent, modifiedAt: time.Now(), } return nil } // normalizePath normalizes a file path by ensuring it starts with "/" and removing trailing slashes. 
func normalizePath(path string) string { if path == "" { return "/" } // Ensure path starts with "/" if !strings.HasPrefix(path, "/") { path = "/" + path } return filepath.Clean(path) } type grepCollector struct { allMatches []GrepMatch } func newGrepCollector() *grepCollector { return &grepCollector{ allMatches: []GrepMatch{}, } } func (c *grepCollector) processFile(filePath, content string, re *regexp.Regexp, req *GrepRequest) { fileMatches := c.findMatches(filePath, content, re, req) if len(fileMatches) > 0 { c.allMatches = append(c.allMatches, fileMatches...) } } func (c *grepCollector) findMatches(filePath, content string, re *regexp.Regexp, req *GrepRequest) []GrepMatch { if req.EnableMultiline { return c.findMultilineMatches(filePath, content, re) } return c.findSingleLineMatches(filePath, content, re) } func (c *grepCollector) findMultilineMatches(filePath, content string, re *regexp.Regexp) []GrepMatch { var fileMatches []GrepMatch matches := re.FindAllStringIndex(content, -1) lines := strings.Split(content, "\n") for _, match := range matches { matchStart := match[0] matchEnd := match[1] startLineNum := 1 + strings.Count(content[:matchStart], "\n") endLineNum := 1 + strings.Count(content[:matchEnd], "\n") for lineNum := startLineNum; lineNum <= endLineNum && lineNum <= len(lines); lineNum++ { fileMatches = append(fileMatches, GrepMatch{ Path: filePath, Line: lineNum, Content: lines[lineNum-1], }) } } return fileMatches } func (c *grepCollector) findSingleLineMatches(filePath, content string, re *regexp.Regexp) []GrepMatch { var fileMatches []GrepMatch lines := strings.Split(content, "\n") for lineNum, line := range lines { if re.MatchString(line) { fileMatches = append(fileMatches, GrepMatch{ Path: filePath, Line: lineNum + 1, Content: line, }) } } return fileMatches } func (c *grepCollector) buildResults(b *InMemoryBackend, req *GrepRequest) ([]GrepMatch, error) { return c.buildContentResult(b, req), nil } func (c *grepCollector) buildContentResult(b 
*InMemoryBackend, req *GrepRequest) []GrepMatch { results := c.allMatches if req.BeforeLines > 0 || req.AfterLines > 0 { results = b.applyContext(c.allMatches, req) } return results } ================================================ FILE: adk/filesystem/backend_inmemory_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package filesystem import ( "context" "fmt" "path/filepath" "strings" "testing" "time" ) func TestInMemoryBackend_WriteAndRead(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() // Test Write err := backend.Write(ctx, &WriteRequest{ FilePath: "/test.txt", Content: "line1\nline2\nline3\nline4\nline5", }) if err != nil { t.Fatalf("Write failed: %v", err) } // Test Read - full content content, err := backend.Read(ctx, &ReadRequest{ FilePath: "/test.txt", Limit: 100, }) if err != nil { t.Fatalf("Read failed: %v", err) } expected := "line1\nline2\nline3\nline4\nline5" if content.Content != expected { t.Errorf("Read content mismatch. Expected: %q, Got: %q", expected, content.Content) } // Test Read - with offset and limit content, err = backend.Read(ctx, &ReadRequest{ FilePath: "/test.txt", Offset: 1, Limit: 2, }) if err != nil { t.Fatalf("Read with offset failed: %v", err) } expected = "line1\nline2" if content.Content != expected { t.Errorf("Read with offset content mismatch. 
Expected: %q, Got: %q", expected, content.Content) } // Test Read - non-existent file _, err = backend.Read(ctx, &ReadRequest{ FilePath: "/nonexistent.txt", Limit: 10, }) if err == nil { t.Error("Expected error for non-existent file, got nil") } } func TestInMemoryBackend_LsInfo(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() // Create some files backend.Write(ctx, &WriteRequest{ FilePath: "/file1.txt", Content: "content1", }) backend.Write(ctx, &WriteRequest{ FilePath: "/file2.txt", Content: "content2", }) backend.Write(ctx, &WriteRequest{ FilePath: "/dir1/file3.txt", Content: "content3", }) backend.Write(ctx, &WriteRequest{ FilePath: "/dir1/subdir/file4.txt", Content: "content4", }) backend.Write(ctx, &WriteRequest{ FilePath: "/dir2/file5.txt", Content: "content5", }) // Test LsInfo - root infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos) != 4 { // file1.txt, file2.txt, dir1, dir2 t.Errorf("Expected 4 items in root, got %d", len(infos)) } // Test LsInfo - specific directory infos, err = backend.LsInfo(ctx, &LsInfoRequest{Path: "/dir1"}) if err != nil { t.Fatalf("LsInfo for /dir1 failed: %v", err) } if len(infos) != 2 { // file3.txt, subdir t.Errorf("Expected 2 items in /dir1, got %d", len(infos)) } } func TestInMemoryBackend_Edit(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() // Create a file backend.Write(ctx, &WriteRequest{ FilePath: "/edit.txt", Content: "hello world\nhello again\nhello world", }) // Test Edit - report error if old string occurs err := backend.Edit(ctx, &EditRequest{ FilePath: "/edit.txt", OldString: "hello", NewString: "hi", ReplaceAll: false, }) if err == nil { t.Fatal("should have failed") } // Test Edit - replace all occurrences backend.Write(ctx, &WriteRequest{ FilePath: "/edit2.txt", Content: "hello world\nhello again\nhello world", }) err = backend.Edit(ctx, &EditRequest{ FilePath: "/edit2.txt", 
OldString: "hello", NewString: "hi", ReplaceAll: true, }) if err != nil { t.Fatalf("Edit (replace all) failed: %v", err) } content, _ := backend.Read(ctx, &ReadRequest{ FilePath: "/edit2.txt", Limit: 100, }) expected := "hi world\nhi again\nhi world" if content.Content != expected { t.Errorf("Edit (replace all) content mismatch. Expected: %q, Got: %q", expected, content.Content) } // Test Edit - non-existent file err = backend.Edit(ctx, &EditRequest{ FilePath: "/nonexistent.txt", OldString: "old", NewString: "new", ReplaceAll: false, }) if err == nil { t.Error("Expected error for non-existent file, got nil") } // Test Edit - empty oldString err = backend.Edit(ctx, &EditRequest{ FilePath: "/edit.txt", OldString: "", NewString: "new", ReplaceAll: false, }) if err == nil { t.Error("Expected error for empty oldString, got nil") } } func TestInMemoryBackend_LsInfo_PathIsFilename(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() backend.Write(ctx, &WriteRequest{ FilePath: "/file1.txt", Content: "content1", }) backend.Write(ctx, &WriteRequest{ FilePath: "/file2.txt", Content: "content2", }) backend.Write(ctx, &WriteRequest{ FilePath: "/dir1/file3.txt", Content: "content3", }) backend.Write(ctx, &WriteRequest{ FilePath: "/dir1/subdir/file4.txt", Content: "content4", }) t.Run("RootDirectory", func(t *testing.T) { infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } for _, info := range infos { if strings.Contains(info.Path, "/") { t.Errorf("Path should be filename only, got: %s", info.Path) } if info.IsDir { if info.Path != "dir1" { t.Errorf("Expected directory name 'dir1', got: %s", info.Path) } } else { if info.Path != "file1.txt" && info.Path != "file2.txt" { t.Errorf("Expected filename 'file1.txt' or 'file2.txt', got: %s", info.Path) } } } }) t.Run("Subdirectory", func(t *testing.T) { infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/dir1"}) if err != nil { t.Fatalf("LsInfo 
failed: %v", err) } for _, info := range infos { if strings.Contains(info.Path, "/") { t.Errorf("Path should be filename only, got: %s", info.Path) } if info.IsDir { if info.Path != "subdir" { t.Errorf("Expected directory name 'subdir', got: %s", info.Path) } } else { if info.Path != "file3.txt" { t.Errorf("Expected filename 'file3.txt', got: %s", info.Path) } } } }) t.Run("NestedSubdirectory", func(t *testing.T) { infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/dir1/subdir"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos) != 1 { t.Fatalf("Expected 1 file, got %d", len(infos)) } info := infos[0] if info.Path != "file4.txt" { t.Errorf("Expected filename 'file4.txt', got: %s", info.Path) } if strings.Contains(info.Path, "/") { t.Errorf("Path should be filename only, got: %s", info.Path) } }) } func TestInMemoryBackend_GlobInfo(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() // Create some files backend.Write(ctx, &WriteRequest{ FilePath: "/file1.txt", Content: "content1", }) backend.Write(ctx, &WriteRequest{ FilePath: "/file2.py", Content: "content2", }) backend.Write(ctx, &WriteRequest{ FilePath: "/dir1/file3.txt", Content: "content3", }) backend.Write(ctx, &WriteRequest{ FilePath: "/dir1/file4.py", Content: "content4", }) // Test GlobInfo - match .txt files in root only infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "*.txt", Path: "/", }) if err != nil { t.Fatalf("GlobInfo failed: %v", err) } if len(infos) != 1 { // only file1.txt in root t.Errorf("Expected 1 .txt file in root, got %d", len(infos)) } if infos[0].Path != "file1.txt" { t.Errorf("Expected relative path 'file1.txt', got %s", infos[0].Path) } // Test GlobInfo - match all .py files in dir1 infos, err = backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "*.py", Path: "/dir1", }) if err != nil { t.Fatalf("GlobInfo for /dir1 failed: %v", err) } if len(infos) != 1 { // file4.py t.Errorf("Expected 1 .py file in /dir1, got %d", len(infos)) 
} if infos[0].Path != "file4.py" { t.Errorf("Expected relative path 'file4.py', got %s", infos[0].Path) } } func TestInMemoryBackend_GlobInfo_RelativePath(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() backend.Write(ctx, &WriteRequest{ FilePath: "/Users/bytedance/Desktop/github/eino/file1.go", Content: "content1", }) backend.Write(ctx, &WriteRequest{ FilePath: "/Users/bytedance/Desktop/github/openai-go/paginationmanual_test.go", Content: "content2", }) backend.Write(ctx, &WriteRequest{ FilePath: "/Users/bytedance/Desktop/github/openai-go/paginationauto_test.go", Content: "content3", }) backend.Write(ctx, &WriteRequest{ FilePath: "/Users/bytedance/Desktop/other/test.go", Content: "content4", }) t.Run("GlobFromRootWithPattern", func(t *testing.T) { infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "**/*.go", Path: "/Users/bytedance/Desktop/github", }) if err != nil { t.Fatalf("GlobInfo failed: %v", err) } if len(infos) != 3 { t.Fatalf("Expected 3 .go files, got %d", len(infos)) } expectedPaths := map[string]bool{ "eino/file1.go": false, "openai-go/paginationmanual_test.go": false, "openai-go/paginationauto_test.go": false, } for _, info := range infos { if _, exists := expectedPaths[info.Path]; exists { expectedPaths[info.Path] = true } else { t.Errorf("Unexpected path: %s", info.Path) } } for path, found := range expectedPaths { if !found { t.Errorf("Expected path not found: %s", path) } } }) t.Run("GlobFromSubdirectory", func(t *testing.T) { infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "*.go", Path: "/Users/bytedance/Desktop/github/openai-go", }) if err != nil { t.Fatalf("GlobInfo failed: %v", err) } if len(infos) != 2 { t.Fatalf("Expected 2 .go files, got %d", len(infos)) } expectedPaths := map[string]bool{ "paginationmanual_test.go": false, "paginationauto_test.go": false, } for _, info := range infos { if _, exists := expectedPaths[info.Path]; exists { expectedPaths[info.Path] = true } else { 
t.Errorf("Unexpected path: %s", info.Path) } } for path, found := range expectedPaths { if !found { t.Errorf("Expected path not found: %s", path) } } }) t.Run("GlobFromRootWithAbsolutePattern", func(t *testing.T) { infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "/Users/bytedance/Desktop/github/**/*.go", Path: "/", }) if err != nil { t.Fatalf("GlobInfo failed: %v", err) } expected := map[string]bool{ "/Users/bytedance/Desktop/github/eino/file1.go": false, "/Users/bytedance/Desktop/github/openai-go/paginationmanual_test.go": false, "/Users/bytedance/Desktop/github/openai-go/paginationauto_test.go": false, } for _, info := range infos { if _, ok := expected[info.Path]; ok { expected[info.Path] = true } } for path, found := range expected { if !found { t.Errorf("Expected absolute path not found: %s", path) } } }) t.Run("GlobRecursiveWithRelativePattern", func(t *testing.T) { infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "**/*.go", Path: "/Users/bytedance/Desktop/github", }) if err != nil { t.Fatalf("GlobInfo failed: %v", err) } if len(infos) != 3 { t.Fatalf("Expected 3 .go files with ** pattern, got %d", len(infos)) } expected := map[string]bool{ "eino/file1.go": false, "openai-go/paginationmanual_test.go": false, "openai-go/paginationauto_test.go": false, } for _, info := range infos { if _, ok := expected[info.Path]; ok { expected[info.Path] = true } else { t.Errorf("Unexpected path: %s", info.Path) } } for path, found := range expected { if !found { t.Errorf("Expected relative path not found: %s", path) } } }) } func TestInMemoryBackend_GlobInfo_RecursivePattern(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() backend.Write(ctx, &WriteRequest{ FilePath: "/project/src/main.go", Content: "main", }) backend.Write(ctx, &WriteRequest{ FilePath: "/project/src/utils/helper.go", Content: "helper", }) backend.Write(ctx, &WriteRequest{ FilePath: "/project/src/utils/deep/nested.go", Content: "nested", }) 
backend.Write(ctx, &WriteRequest{ FilePath: "/project/test/test.go", Content: "test", }) backend.Write(ctx, &WriteRequest{ FilePath: "/project/README.md", Content: "readme", }) t.Run("DoubleStarMatchesAllSubdirectories", func(t *testing.T) { infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "**/*.go", Path: "/project", }) if err != nil { t.Fatalf("GlobInfo failed: %v", err) } if len(infos) != 4 { t.Fatalf("Expected 4 .go files, got %d", len(infos)) } expected := map[string]bool{ "src/main.go": false, "src/utils/helper.go": false, "src/utils/deep/nested.go": false, "test/test.go": false, } for _, info := range infos { if _, ok := expected[info.Path]; ok { expected[info.Path] = true } else { t.Errorf("Unexpected path: %s", info.Path) } } for path, found := range expected { if !found { t.Errorf("Expected path not found: %s", path) } } }) t.Run("DoubleStarInMiddleOfPattern", func(t *testing.T) { infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "src/**/*.go", Path: "/project", }) if err != nil { t.Fatalf("GlobInfo failed: %v", err) } if len(infos) != 3 { t.Fatalf("Expected 3 .go files under src/, got %d", len(infos)) } expected := map[string]bool{ "src/main.go": false, "src/utils/helper.go": false, "src/utils/deep/nested.go": false, } for _, info := range infos { if _, ok := expected[info.Path]; ok { expected[info.Path] = true } else { t.Errorf("Unexpected path: %s", info.Path) } } for path, found := range expected { if !found { t.Errorf("Expected path not found: %s", path) } } }) t.Run("DoubleStarAtEnd", func(t *testing.T) { infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: "src/**", Path: "/project", }) if err != nil { t.Fatalf("GlobInfo failed: %v", err) } if len(infos) != 3 { t.Fatalf("Expected 3 files under src/, got %d", len(infos)) } expected := map[string]bool{ "src/main.go": false, "src/utils/helper.go": false, "src/utils/deep/nested.go": false, } for _, info := range infos { if _, ok := expected[info.Path]; ok { 
expected[info.Path] = true
			}
		}
		for path, found := range expected {
			if !found {
				t.Errorf("Expected path not found: %s", path)
			}
		}
	})

	t.Run("AbsolutePatternWithDoubleStarRecursive", func(t *testing.T) {
		infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{
			Pattern: "/project/**/*.go",
			Path:    "/",
		})
		if err != nil {
			t.Fatalf("GlobInfo failed: %v", err)
		}
		if len(infos) != 4 {
			t.Fatalf("Expected 4 .go files, got %d", len(infos))
		}
		expected := map[string]bool{
			"/project/src/main.go":              false,
			"/project/src/utils/helper.go":      false,
			"/project/src/utils/deep/nested.go": false,
			"/project/test/test.go":             false,
		}
		for _, info := range infos {
			if _, ok := expected[info.Path]; ok {
				expected[info.Path] = true
			} else {
				t.Errorf("Unexpected path: %s", info.Path)
			}
		}
		for path, found := range expected {
			if !found {
				t.Errorf("Expected absolute path not found: %s", path)
			}
		}
	})
}

// TestInMemoryBackend_Concurrent starts ten goroutines that each write and
// then read the same file, and waits for all of them. It only checks that the
// calls do not fail; data races are detected only when run with -race.
func TestInMemoryBackend_Concurrent(t *testing.T) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	// Test concurrent writes and reads
	done := make(chan bool)
	for i := 0; i < 10; i++ {
		go func(n int) {
			// Report failures rather than discarding them; t.Errorf may be
			// called from other goroutines while the test waits on done.
			if err := backend.Write(ctx, &WriteRequest{
				FilePath: "/concurrent.txt",
				Content:  "content",
			}); err != nil {
				t.Errorf("goroutine %d: Write failed: %v", n, err)
			}
			if _, err := backend.Read(ctx, &ReadRequest{
				FilePath: "/concurrent.txt",
				Limit:    10,
			}); err != nil {
				t.Errorf("goroutine %d: Read failed: %v", n, err)
			}
			done <- true
		}(i)
	}
	for i := 0; i < 10; i++ {
		<-done
	}
}

// TestInMemoryBackend_LsInfo_FileInfoMetadata checks the Path, IsDir, Size and
// ModifiedAt fields that LsInfo reports for files and for the directories
// implied by nested file paths.
func TestInMemoryBackend_LsInfo_FileInfoMetadata(t *testing.T) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	t.Run("FileMetadata", func(t *testing.T) {
		content := "hello world"
		err := backend.Write(ctx, &WriteRequest{
			FilePath: "/test.txt",
			Content:  content,
		})
		if err != nil {
			t.Fatalf("Write failed: %v", err)
		}

		infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"})
		if err != nil {
			t.Fatalf("LsInfo failed: %v", err)
		}
		if len(infos) != 1 {
			t.Fatalf("Expected 1 file, got %d", len(infos))
		}

		info := infos[0]
		if info.Path != "test.txt" {
			t.Errorf("Expected path test.txt, got %s", info.Path)
		}
		if info.IsDir {
			t.Error("Expected IsDir to be false for file")
		}
		if info.Size != int64(len(content)) {
			t.Errorf("Expected size %d, got %d", len(content), info.Size)
		}
		if info.ModifiedAt == "" {
			t.Error("Expected ModifiedAt to be non-empty")
		}
		_, err = time.Parse(time.RFC3339Nano, info.ModifiedAt)
		if err != nil {
			t.Errorf("ModifiedAt is not valid RFC3339 format: %v", err)
		}
	})

	t.Run("DirectoryMetadata", func(t *testing.T) {
		backend := NewInMemoryBackend()
		err := backend.Write(ctx, &WriteRequest{
			FilePath: "/dir1/file1.txt",
			Content:  "content1",
		})
		if err != nil {
			t.Fatalf("Write failed: %v", err)
		}

		infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"})
		if err != nil {
			t.Fatalf("LsInfo failed: %v", err)
		}
		if len(infos) != 1 {
			t.Fatalf("Expected 1 directory, got %d", len(infos))
		}

		info := infos[0]
		if info.Path != "dir1" {
			t.Errorf("Expected path dir1, got %s", info.Path)
		}
		if !info.IsDir {
			t.Error("Expected IsDir to be true for directory")
		}
		if info.Size != 0 {
			t.Errorf("Expected size 0 for directory, got %d", info.Size)
		}
		if info.ModifiedAt == "" {
			t.Error("Expected ModifiedAt to be non-empty for directory")
		}
	})

	t.Run("MixedFilesAndDirectories", func(t *testing.T) {
		backend := NewInMemoryBackend()
		backend.Write(ctx, &WriteRequest{
			FilePath: "/file1.txt",
			Content:  "content1",
		})
		backend.Write(ctx, &WriteRequest{
			FilePath: "/dir1/file2.txt",
			Content:  "content2",
		})
		backend.Write(ctx, &WriteRequest{
			FilePath: "/dir1/subdir/file3.txt",
			Content:  "content3",
		})

		infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"})
		if err != nil {
			t.Fatalf("LsInfo failed: %v", err)
		}
		if len(infos) != 2 {
			t.Fatalf("Expected 2 items (file1.txt, dir1), got %d", len(infos))
		}

		fileCount := 0
		dirCount := 0
		for _, info := range infos {
			if info.IsDir {
				dirCount++
				if info.Path != "dir1" {
					t.Errorf("Expected directory path dir1, got %s", info.Path)
				}
			} else {
				fileCount++
				if info.Path != "file1.txt" {
					t.Errorf("Expected file path file1.txt, got %s", info.Path)
				}
				if info.Size != int64(len("content1")) {
					t.Errorf("Expected file size %d, got %d", len("content1"), info.Size)
				}
			}
		}
		if fileCount != 1 {
			t.Errorf("Expected 1 file, got %d", fileCount)
		}
		if dirCount != 1 {
			t.Errorf("Expected 1 directory, got %d", dirCount)
		}
	})

	t.Run("SubdirectoryListing", func(t *testing.T) {
		backend := NewInMemoryBackend()
		backend.Write(ctx, &WriteRequest{
			FilePath: "/dir1/file1.txt",
			Content:  "short",
		})
		backend.Write(ctx, &WriteRequest{
			FilePath: "/dir1/subdir/file2.txt",
			Content:  "longer content here",
		})

		infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/dir1"})
		if err != nil {
			t.Fatalf("LsInfo failed: %v", err)
		}
		if len(infos) != 2 {
			t.Fatalf("Expected 2 items (file1.txt, subdir), got %d", len(infos))
		}

		for _, info := range infos {
			if info.Path == "file1.txt" {
				if info.IsDir {
					t.Error("Expected file1.txt to be a file")
				}
				if info.Size != int64(len("short")) {
					t.Errorf("Expected size %d, got %d", len("short"), info.Size)
				}
			} else if info.Path == "subdir" {
				if !info.IsDir {
					t.Error("Expected subdir to be a directory")
				}
			} else {
				t.Errorf("Unexpected path: %s", info.Path)
			}
		}
	})

	t.Run("DirectoryModifiedAtUsesLatestFile", func(t *testing.T) {
		backend := NewInMemoryBackend()
		backend.Write(ctx, &WriteRequest{
			FilePath: "/dir1/file1.txt",
			Content:  "content1",
		})
		// Space the writes out so the two files get distinct timestamps.
		time.Sleep(10 * time.Millisecond)
		backend.Write(ctx, &WriteRequest{
			FilePath: "/dir1/file2.txt",
			Content:  "content2",
		})

		infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"})
		if err != nil {
			t.Fatalf("LsInfo failed: %v", err)
		}
		if len(infos) != 1 {
			t.Fatalf("Expected 1 directory, got %d", len(infos))
		}
		dirInfo := infos[0]
		if !dirInfo.IsDir {
			t.Fatal("Expected directory")
		}
		dirModTime, err := time.Parse(time.RFC3339Nano, dirInfo.ModifiedAt)
		if err != nil {
			t.Fatalf("Failed to parse directory ModifiedAt: %v", err)
		}

		subInfos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/dir1"})
		if err != nil {
			t.Fatalf("LsInfo failed: %v", err)
		}
		var latestFileTime time.Time
		for _, info := range subInfos {
			fileTime, err := time.Parse(time.RFC3339Nano, info.ModifiedAt)
			if err != nil {
				t.Fatalf("Failed to parse ModifiedAt of %s: %v", info.Path, err)
			}
			if fileTime.After(latestFileTime) {
				latestFileTime = fileTime
			}
		}

		// Fail (not merely log) when the directory is reported as older than
		// its newest file; Before already implies !Equal.
		if dirModTime.Before(latestFileTime) {
			t.Errorf("Directory mod time %v is older than latest file time %v", dirModTime, latestFileTime)
		}
	})
}

// TestInMemoryBackend_GlobInfo_FileInfoMetadata checks that GlobInfo results
// carry the same file metadata (IsDir, Size, ModifiedAt) as LsInfo does.
func TestInMemoryBackend_GlobInfo_FileInfoMetadata(t *testing.T) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	t.Run("BasicMetadata", func(t *testing.T) {
		content := "test content"
		backend.Write(ctx, &WriteRequest{
			FilePath: "/test.txt",
			Content:  content,
		})

		infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{
			Pattern: "*.txt",
			Path:    "/",
		})
		if err != nil {
			t.Fatalf("GlobInfo failed: %v", err)
		}
		if len(infos) != 1 {
			t.Fatalf("Expected 1 file, got %d", len(infos))
		}

		info := infos[0]
		if info.Path != "test.txt" {
			t.Errorf("Expected path test.txt, got %s", info.Path)
		}
		if info.IsDir {
			t.Error("Expected IsDir to be false")
		}
		if info.Size != int64(len(content)) {
			t.Errorf("Expected size %d, got %d", len(content), info.Size)
		}
		if info.ModifiedAt == "" {
			t.Error("Expected ModifiedAt to be non-empty")
		}
	})

	t.Run("MultipleFilesMetadata", func(t *testing.T) {
		backend := NewInMemoryBackend()
		backend.Write(ctx, &WriteRequest{
			FilePath: "/file1.txt",
			Content:  "short",
		})
		backend.Write(ctx, &WriteRequest{
			FilePath: "/file2.txt",
			Content:  "much longer content",
		})
		backend.Write(ctx, &WriteRequest{
			FilePath: "/file3.py",
			Content:  "python",
		})

		infos, err := backend.GlobInfo(ctx, &GlobInfoRequest{
			Pattern: "*.txt",
			Path:    "/",
		})
		if err != nil {
			t.Fatalf("GlobInfo failed: %v", err)
		}
		if len(infos) != 2 {
			t.Fatalf("Expected 2 .txt files, got %d", len(infos))
		}

		for _, info := range infos {
			if info.IsDir {
				t.Errorf("Expected IsDir to be false for %s", info.Path)
			}
			if info.Size <= 0 {
				t.Errorf("Expected positive size for %s, got %d", info.Path, info.Size)
			}
			if info.ModifiedAt == "" {
				t.Errorf("Expected ModifiedAt to be non-empty for %s", info.Path)
			}
		}
	})
}

func TestInMemoryBackend_WriteAndEdit_ModifiedAt(t *testing.T) {
	ctx := context.Background()

	t.Run("WriteUpdatesModifiedAt", func(t *testing.T) {
		backend := NewInMemoryBackend()
		beforeWrite := time.Now()
		time.Sleep(1 *
time.Millisecond) err := backend.Write(ctx, &WriteRequest{ FilePath: "/test.txt", Content: "initial content", }) if err != nil { t.Fatalf("Write failed: %v", err) } time.Sleep(1 * time.Millisecond) afterWrite := time.Now() infos, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos) != 1 { t.Fatalf("Expected 1 file, got %d", len(infos)) } modTime, err := time.Parse(time.RFC3339Nano, infos[0].ModifiedAt) if err != nil { t.Fatalf("Failed to parse ModifiedAt: %v", err) } if modTime.Before(beforeWrite) || modTime.After(afterWrite) { t.Errorf("ModifiedAt %v should be between %v and %v", modTime, beforeWrite, afterWrite) } }) t.Run("EditUpdatesModifiedAt", func(t *testing.T) { backend := NewInMemoryBackend() err := backend.Write(ctx, &WriteRequest{ FilePath: "/edit.txt", Content: "hello world", }) if err != nil { t.Fatalf("Write failed: %v", err) } infos1, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos1) != 1 { t.Fatalf("Expected 1 file, got %d", len(infos1)) } modTime1, err := time.Parse(time.RFC3339Nano, infos1[0].ModifiedAt) if err != nil { t.Fatalf("Failed to parse ModifiedAt: %v", err) } time.Sleep(10 * time.Millisecond) err = backend.Edit(ctx, &EditRequest{ FilePath: "/edit.txt", OldString: "hello", NewString: "hi", ReplaceAll: true, }) if err != nil { t.Fatalf("Edit failed: %v", err) } infos2, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos2) != 1 { t.Fatalf("Expected 1 file, got %d", len(infos2)) } modTime2, err := time.Parse(time.RFC3339Nano, infos2[0].ModifiedAt) if err != nil { t.Fatalf("Failed to parse ModifiedAt: %v", err) } if !modTime2.After(modTime1) { t.Errorf("ModifiedAt should be updated after edit. 
Before: %v, After: %v", modTime1, modTime2) } }) t.Run("OverwriteUpdatesModifiedAt", func(t *testing.T) { backend := NewInMemoryBackend() err := backend.Write(ctx, &WriteRequest{ FilePath: "/overwrite.txt", Content: "original", }) if err != nil { t.Fatalf("Write failed: %v", err) } infos1, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos1) != 1 { t.Fatalf("Expected 1 file, got %d", len(infos1)) } modTime1, err := time.Parse(time.RFC3339Nano, infos1[0].ModifiedAt) if err != nil { t.Fatalf("Failed to parse ModifiedAt: %v", err) } time.Sleep(10 * time.Millisecond) err = backend.Write(ctx, &WriteRequest{ FilePath: "/overwrite.txt", Content: "new content", }) if err != nil { t.Fatalf("Write failed: %v", err) } infos2, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos2) != 1 { t.Fatalf("Expected 1 file, got %d", len(infos2)) } modTime2, err := time.Parse(time.RFC3339Nano, infos2[0].ModifiedAt) if err != nil { t.Fatalf("Failed to parse ModifiedAt: %v", err) } if !modTime2.After(modTime1) { t.Errorf("ModifiedAt should be updated after overwrite. 
Before: %v, After: %v", modTime1, modTime2) } }) t.Run("SizeUpdatesAfterEdit", func(t *testing.T) { backend := NewInMemoryBackend() err := backend.Write(ctx, &WriteRequest{ FilePath: "/size.txt", Content: "hello", }) if err != nil { t.Fatalf("Write failed: %v", err) } infos1, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos1) != 1 { t.Fatalf("Expected 1 file, got %d", len(infos1)) } size1 := infos1[0].Size err = backend.Edit(ctx, &EditRequest{ FilePath: "/size.txt", OldString: "hello", NewString: "hello world", ReplaceAll: true, }) if err != nil { t.Fatalf("Edit failed: %v", err) } infos2, err := backend.LsInfo(ctx, &LsInfoRequest{Path: "/"}) if err != nil { t.Fatalf("LsInfo failed: %v", err) } if len(infos2) != 1 { t.Fatalf("Expected 1 file, got %d", len(infos2)) } size2 := infos2[0].Size if size2 <= size1 { t.Errorf("Size should increase after edit. Before: %d, After: %d", size1, size2) } if size2 != int64(len("hello world")) { t.Errorf("Expected size %d, got %d", len("hello world"), size2) } }) } func TestInMemoryBackend_Read_EdgeCases(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() backend.Write(ctx, &WriteRequest{ FilePath: "/test.txt", Content: "line1\nline2\nline3", }) t.Run("negative offset should be treated as zero", func(t *testing.T) { content, err := backend.Read(ctx, &ReadRequest{ FilePath: "/test.txt", Offset: -5, Limit: 2, }) if err != nil { t.Fatalf("Read failed: %v", err) } expected := "line1\nline2" if content.Content != expected { t.Errorf("Expected: %q, Got: %q", expected, content.Content) } }) t.Run("offset exceeds file length", func(t *testing.T) { content, err := backend.Read(ctx, &ReadRequest{ FilePath: "/test.txt", Offset: 100, Limit: 10, }) if err != nil { t.Fatalf("Read failed: %v", err) } if content.Content != "" { t.Errorf("Expected empty content, got: %q", content.Content) } }) t.Run("zero or negative limit should use default 200", func(t 
*testing.T) { content, err := backend.Read(ctx, &ReadRequest{ FilePath: "/test.txt", Offset: 0, Limit: 0, }) if err != nil { t.Fatalf("Read failed: %v", err) } lines := strings.Split(content.Content, "\n") if len(lines) != 3 { t.Errorf("Expected 3 lines, got %d", len(lines)) } }) t.Run("limit exceeds remaining lines", func(t *testing.T) { content, err := backend.Read(ctx, &ReadRequest{ FilePath: "/test.txt", Offset: 1, Limit: 100, }) if err != nil { t.Fatalf("Read failed: %v", err) } lines := strings.Split(content.Content, "\n") if len(lines) != 3 { t.Errorf("Expected 3 lines, got %d", len(lines)) } }) } func TestInMemoryBackend_Edit_EdgeCases(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() t.Run("edit non-existent file", func(t *testing.T) { err := backend.Edit(ctx, &EditRequest{ FilePath: "/nonexistent.txt", OldString: "old", NewString: "new", }) if err == nil { t.Error("Expected error for non-existent file") } if !strings.Contains(err.Error(), "not found") { t.Errorf("Expected 'not found' error, got: %v", err) } }) t.Run("empty oldString", func(t *testing.T) { backend.Write(ctx, &WriteRequest{ FilePath: "/test.txt", Content: "content", }) err := backend.Edit(ctx, &EditRequest{ FilePath: "/test.txt", OldString: "", NewString: "new", }) if err == nil { t.Error("Expected error for empty oldString") } if !strings.Contains(err.Error(), "non-empty") { t.Errorf("Expected 'non-empty' error, got: %v", err) } }) t.Run("oldString not found", func(t *testing.T) { backend.Write(ctx, &WriteRequest{ FilePath: "/test.txt", Content: "hello world", }) err := backend.Edit(ctx, &EditRequest{ FilePath: "/test.txt", OldString: "notfound", NewString: "new", }) if err == nil { t.Error("Expected error when oldString not found") } if !strings.Contains(err.Error(), "not found in file") { t.Errorf("Expected 'not found in file' error, got: %v", err) } }) t.Run("multiple occurrences with ReplaceAll false", func(t *testing.T) { backend.Write(ctx, &WriteRequest{ 
FilePath: "/test.txt", Content: "foo bar foo baz", }) err := backend.Edit(ctx, &EditRequest{ FilePath: "/test.txt", OldString: "foo", NewString: "FOO", ReplaceAll: false, }) if err == nil { t.Error("Expected error for multiple occurrences with ReplaceAll=false") } if !strings.Contains(err.Error(), "multiple occurrences") { t.Errorf("Expected 'multiple occurrences' error, got: %v", err) } }) t.Run("single occurrence with ReplaceAll false", func(t *testing.T) { backend.Write(ctx, &WriteRequest{ FilePath: "/test.txt", Content: "foo bar baz", }) err := backend.Edit(ctx, &EditRequest{ FilePath: "/test.txt", OldString: "foo", NewString: "FOO", ReplaceAll: false, }) if err != nil { t.Fatalf("Edit failed: %v", err) } content, _ := backend.Read(ctx, &ReadRequest{ FilePath: "/test.txt", Limit: 100, }) if !strings.Contains(content.Content, "FOO") { t.Error("Expected content to contain 'FOO'") } }) t.Run("ReplaceAll replaces all occurrences", func(t *testing.T) { backend.Write(ctx, &WriteRequest{ FilePath: "/test.txt", Content: "foo bar foo baz foo", }) err := backend.Edit(ctx, &EditRequest{ FilePath: "/test.txt", OldString: "foo", NewString: "FOO", ReplaceAll: true, }) if err != nil { t.Fatalf("Edit failed: %v", err) } content, _ := backend.Read(ctx, &ReadRequest{ FilePath: "/test.txt", Limit: 100, }) if strings.Contains(content.Content, "foo") { t.Error("Expected all 'foo' to be replaced") } fooCount := strings.Count(content.Content, "FOO") if fooCount != 3 { t.Errorf("Expected 3 occurrences of 'FOO', got %d", fooCount) } }) } func TestInMemoryBackend_NormalizePath(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() t.Run("paths are normalized on write", func(t *testing.T) { testCases := []struct { inputPath string normalizedPath string }{ {"test.txt", "/test.txt"}, {"/test.txt", "/test.txt"}, {"//test.txt", "/test.txt"}, {"/dir//file.txt", "/dir/file.txt"}, {"/dir/../file.txt", "/file.txt"}, } for _, tc := range testCases { backend.Write(ctx, 
&WriteRequest{ FilePath: tc.inputPath, Content: "content", }) content, err := backend.Read(ctx, &ReadRequest{ FilePath: tc.normalizedPath, Limit: 10, }) if err != nil { t.Errorf("Failed to read normalized path %s (from %s): %v", tc.normalizedPath, tc.inputPath, err) } if !strings.Contains(content.Content, "content") { t.Errorf("Content not found for normalized path %s (from %s)", tc.normalizedPath, tc.inputPath) } } }) } func TestInMemoryBackend_MatchFileType(t *testing.T) { testCases := []struct { ext string fileType string expected bool }{ {"go", "go", true}, {"py", "python", true}, {"py", "py", true}, {"js", "js", true}, {"ts", "typescript", true}, {"ts", "ts", true}, {"cpp", "cpp", true}, {"c", "c", true}, {"h", "c", true}, {"md", "markdown", true}, {"txt", "txt", true}, {"go", "python", false}, {"js", "typescript", false}, {"unknown", "go", false}, } for _, tc := range testCases { t.Run(fmt.Sprintf("%s matches %s", tc.ext, tc.fileType), func(t *testing.T) { result := matchFileType(tc.ext, tc.fileType) if result != tc.expected { t.Errorf("matchFileType(%q, %q) = %v, expected %v", tc.ext, tc.fileType, result, tc.expected) } }) } } func TestInMemoryBackend_GrepRaw(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() backend.Write(ctx, &WriteRequest{ FilePath: "/test.go", Content: "package main\nfunc main() {\n\tlog.Error(\"error\")\n\tfmt.Println(\"hello\")\n}", }) backend.Write(ctx, &WriteRequest{ FilePath: "/test.py", Content: "def hello():\n print('error')\n print('world')", }) backend.Write(ctx, &WriteRequest{ FilePath: "/dir/file.go", Content: "package test\nfunc TestError() {\n\tlog.Error(\"test error\")\n}", }) t.Run("basic pattern search", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "error", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) != 3 { t.Errorf("Expected 2 matches, got %d", len(matches)) } }) t.Run("empty pattern error", func(t *testing.T) { _, err := 
backend.GrepRaw(ctx, &GrepRequest{ Pattern: "", }) if err == nil { t.Error("Expected error for empty pattern") } if !strings.Contains(err.Error(), "cannot be empty") { t.Errorf("Expected 'cannot be empty' error, got: %v", err) } }) t.Run("invalid regex pattern", func(t *testing.T) { _, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "[invalid", }) if err == nil { t.Error("Expected error for invalid regex") } if !strings.Contains(err.Error(), "invalid regex") { t.Errorf("Expected 'invalid regex' error, got: %v", err) } }) t.Run("case sensitive search", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "Error", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) != 3 { t.Errorf("Expected 2 matches, got %d", len(matches)) } }) t.Run("case insensitive search", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "ERROR", CaseInsensitive: true, }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) < 3 { t.Errorf("Expected at least 2 matches, got %d", len(matches)) } }) t.Run("filter by file type", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "error", FileType: "go", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } for _, match := range matches { if !strings.HasSuffix(match.Path, ".go") { t.Errorf("Expected only .go files, got: %s", match.Path) } } }) t.Run("filter by glob pattern", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "Error", Glob: "*.go", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } for _, match := range matches { if !strings.HasSuffix(match.Path, ".go") { t.Errorf("Expected only .go files, got: %s", match.Path) } } }) t.Run("invalid glob pattern", func(t *testing.T) { _, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "error", Glob: "[invalid", }) if err == nil { t.Error("Expected error for invalid glob pattern") } if !strings.Contains(err.Error(), "invalid glob") { 
t.Errorf("Expected 'invalid glob' error, got: %v", err) } }) t.Run("search in specific path", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "Error", Path: "/dir", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } for _, match := range matches { if !strings.HasPrefix(match.Path, "/dir") { t.Errorf("Expected matches only from /dir, got: %s", match.Path) } } }) t.Run("search with non-existent path", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "error", Path: "/nonexistent", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) != 0 { t.Errorf("Expected 0 matches for non-existent path, got %d", len(matches)) } }) t.Run("regex pattern matching", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "log\\..*Error", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) < 1 { t.Errorf("Expected at least 1 match, got %d", len(matches)) } }) t.Run("no matches found", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "nonexistent_pattern_xyz", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) != 0 { t.Errorf("Expected 0 matches, got %d", len(matches)) } }) t.Run("match line numbers", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "log\\.Error", FileType: "go", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } for _, match := range matches { if match.Line <= 0 { t.Errorf("Expected positive line number, got %d", match.Line) } } }) t.Run("match content is returned", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "package main", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) < 1 { t.Fatal("Expected at least 1 match") } found := false for _, match := range matches { if strings.Contains(match.Content, "package main") { found = true break } } if !found { t.Error("Expected match content to 
contain 'package main'") } }) } func TestInMemoryBackend_GrepRaw_WithContext(t *testing.T) { backend := NewInMemoryBackend() ctx := context.Background() backend.Write(ctx, &WriteRequest{ FilePath: "/context.txt", Content: "line1\nline2\ntarget line\nline4\nline5\nline6", }) t.Run("with before context", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "target", BeforeLines: 2, }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) < 3 { t.Errorf("Expected at least 3 matches (2 before + target), got %d", len(matches)) } }) t.Run("with after context", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "target", AfterLines: 2, }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) < 3 { t.Errorf("Expected at least 3 matches (target + 2 after), got %d", len(matches)) } }) t.Run("with both before and after context", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "target", BeforeLines: 1, AfterLines: 1, }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) < 3 { t.Errorf("Expected at least 3 matches (1 before + target + 1 after), got %d", len(matches)) } }) t.Run("context at file boundaries", func(t *testing.T) { backend.Write(ctx, &WriteRequest{ FilePath: "/boundary.txt", Content: "first line target\nsecond line", }) matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "target", Path: "/boundary.txt", BeforeLines: 5, AfterLines: 5, }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) == 0 { t.Error("Expected at least 1 match") } }) t.Run("zero context lines", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "target", BeforeLines: 0, AfterLines: 0, }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } if len(matches) < 1 { t.Error("Expected at least 1 match") } }) t.Run("negative context lines treated as zero", func(t *testing.T) { matches, err := 
backend.GrepRaw(ctx, &GrepRequest{
			Pattern:     "target",
			BeforeLines: -5,
			AfterLines:  -5,
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) < 1 {
			t.Error("Expected at least 1 match")
		}
	})
}

// TestInMemoryBackend_GrepRaw_Multiline checks that a pattern spanning lines
// matches only when EnableMultiline is set, and that multi-line matches report
// the lines they cover.
func TestInMemoryBackend_GrepRaw_Multiline(t *testing.T) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	backend.Write(ctx, &WriteRequest{
		FilePath: "/multiline.txt",
		Content:  "start\nmiddle line\nend",
	})

	t.Run("single line mode (default)", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern:         "start.*end",
			EnableMultiline: false,
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 0 {
			t.Errorf("Expected 0 matches in single-line mode, got %d", len(matches))
		}
	})

	t.Run("multiline mode enabled", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern:         "start[\\s\\S]*end",
			EnableMultiline: true,
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) == 0 {
			t.Error("Expected matches in multiline mode")
		}
	})

	t.Run("multiline with multiple matches", func(t *testing.T) {
		backend.Write(ctx, &WriteRequest{
			FilePath: "/multiline2.txt",
			Content:  "block1 start\nblock1 middle\nblock1 end\n\nblock2 start\nblock2 end",
		})
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern:         "start[\\s\\S]*?end",
			Path:            "/multiline2.txt",
			EnableMultiline: true,
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) == 0 {
			t.Error("Expected matches in multiline mode")
		}
	})

	t.Run("multiline with multiple matches v2", func(t *testing.T) {
		// NOTE(review): this raw string is reproduced exactly as it appears in
		// the (whitespace-collapsed) source view. The Line == 6 check below
		// depends on its original line breaks — confirm against version control.
		backend.Write(ctx, &WriteRequest{
			FilePath: "/multiline3.txt",
			Content: ` const a = 1; function calculateTotal( items, discount ) { return items.reduce((sum, item) => sum + item.price, 0); } const b = 2; /* * This is a comment * spanning multiple lines */ class UserService { constructor(db) { this.db = db; } } `,
		})
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern:         "function calculateTotal\\([^\\)]*\\)",
			Path:            "/multiline3.txt",
			EnableMultiline: true,
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) == 0 {
			t.Error("Expected matches in multiline mode")
		}

		// The closing ") {" of the multi-line signature must be reported as
		// part of the match, on line 6.
		foundLastLine := false
		for _, match := range matches {
			if match.Line == 6 && strings.Contains(match.Content, ") {") {
				foundLastLine = true
				break
			}
		}
		if !foundLastLine {
			t.Error("Expected to find line 6 with ') {' in content")
			for _, match := range matches {
				t.Logf("Line %d: %s", match.Line, match.Content)
			}
		}
	})
}

// TestInMemoryBackend_GrepRaw_EmptyFiles checks that GrepRaw returns no
// matches (and no error) for an empty file and for a backend with no files.
func TestInMemoryBackend_GrepRaw_EmptyFiles(t *testing.T) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	t.Run("search in empty file", func(t *testing.T) {
		backend.Write(ctx, &WriteRequest{
			FilePath: "/empty.txt",
			Content:  "",
		})
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "anything",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 0 {
			t.Errorf("Expected 0 matches in empty file, got %d", len(matches))
		}
	})

	t.Run("search with no files", func(t *testing.T) {
		emptyBackend := NewInMemoryBackend()
		matches, err := emptyBackend.GrepRaw(ctx, &GrepRequest{
			Pattern: "anything",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 0 {
			t.Errorf("Expected 0 matches with no files, got %d", len(matches))
		}
	})
}

func TestInMemoryBackend_GrepRaw_SpecialCharacters(t *testing.T) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	backend.Write(ctx, &WriteRequest{
		FilePath: "/special.txt",
		Content:  "interface{}\nmap[string]int\nfunc() error\n$variable\n*pointer",
	})

	t.Run("match curly braces", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "interface\\{\\}",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 1 {
			t.Errorf("Expected 1 match, got %d", len(matches))
		}
	})

	t.Run("match square brackets", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "map\\[.*\\]",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if
len(matches) != 1 {
			t.Errorf("Expected 1 match, got %d", len(matches))
		}
	})

	t.Run("match parentheses", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "func\\(\\)",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 1 {
			t.Errorf("Expected 1 match, got %d", len(matches))
		}
	})

	t.Run("match dollar sign", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "\\$variable",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 1 {
			t.Errorf("Expected 1 match, got %d", len(matches))
		}
	})

	t.Run("match asterisk", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "\\*pointer",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 1 {
			t.Errorf("Expected 1 match, got %d", len(matches))
		}
	})
}

// TestInMemoryBackend_GrepRaw_Concurrent runs GrepRaw from several goroutines
// against a shared backend. It also checks match counts for a 100-file
// FileType-filtered search, a single file with three matching lines, and a
// search under a non-existent path.
func TestInMemoryBackend_GrepRaw_Concurrent(t *testing.T) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	for i := 0; i < 10; i++ {
		backend.Write(ctx, &WriteRequest{
			FilePath: fmt.Sprintf("/file%d.txt", i),
			Content:  fmt.Sprintf("content%d with error message", i),
		})
	}

	t.Run("concurrent grep operations", func(t *testing.T) {
		// Ten readers; each signals on done so the subtest waits for all of
		// them before returning (t.Errorf is safe to call from goroutines).
		done := make(chan bool)
		for i := 0; i < 10; i++ {
			go func() {
				_, err := backend.GrepRaw(ctx, &GrepRequest{
					Pattern: "error",
				})
				if err != nil {
					t.Errorf("Concurrent GrepRaw failed: %v", err)
				}
				done <- true
			}()
		}
		for i := 0; i < 10; i++ {
			<-done
		}
	})

	t.Run("parallel file processing", func(t *testing.T) {
		backend := NewInMemoryBackend()
		for i := 0; i < 100; i++ {
			backend.Write(ctx, &WriteRequest{
				FilePath: fmt.Sprintf("/large/file%d.go", i),
				Content:  fmt.Sprintf("package main\nimport \"log\"\nfunc test%d() {\n\tlog.Error(\"error %d\")\n}", i, i),
			})
		}

		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern:  "log\\.Error",
			FileType: "go",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		// Each generated file has exactly one log.Error line.
		if len(matches) != 100 {
			t.Errorf("Expected 100 matches, got %d", len(matches))
		}
	})

	t.Run("single file no parallelism", func(t *testing.T) {
		backend := NewInMemoryBackend()
		backend.Write(ctx, &WriteRequest{
			FilePath: "/single.txt",
			Content:  "error line 1\nerror line 2\nerror line 3",
		})

		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "error",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 3 {
			t.Errorf("Expected 3 matches, got %d", len(matches))
		}
	})

	t.Run("empty files list", func(t *testing.T) {
		backend := NewInMemoryBackend()
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "anything",
			Path:    "/nonexistent",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) != 0 {
			t.Errorf("Expected 0 matches, got %d", len(matches))
		}
	})

	t.Run("concurrent operations are safe", func(t *testing.T) {
		backend := NewInMemoryBackend()
		for i := 0; i < 20; i++ {
			backend.Write(ctx, &WriteRequest{
				FilePath: fmt.Sprintf("/concurrent/file%d.txt", i),
				Content:  fmt.Sprintf("line1\nline2\npattern%d\nline4", i),
			})
		}

		// Buffered so no goroutine blocks on send; errors are collected and
		// reported on the test goroutine. The id parameter is unused.
		done := make(chan error, 5)
		for i := 0; i < 5; i++ {
			go func(id int) {
				_, err := backend.GrepRaw(ctx, &GrepRequest{
					Pattern: "pattern\\d+",
				})
				done <- err
			}(i)
		}

		for i := 0; i < 5; i++ {
			if err := <-done; err != nil {
				t.Errorf("Concurrent operation %d failed: %v", i, err)
			}
		}
	})
}

// BenchmarkInMemoryBackend_GrepRaw measures GrepRaw over 100 generated .go
// files under /project/src, with a FileType filter, a recursive Glob filter,
// and a case-insensitive pattern.
func BenchmarkInMemoryBackend_GrepRaw(b *testing.B) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	for i := 0; i < 100; i++ {
		// NOTE(review): this template appears whitespace-collapsed in the
		// source view; it is reproduced exactly as shown.
		content := fmt.Sprintf(`package main import ( "fmt" "log" ) func process%d() error { log.Error("processing error %d") fmt.Println("hello world") return nil } func calculate%d(x, y int) int { return x + y } `, i, i, i)
		backend.Write(ctx, &WriteRequest{
			FilePath: fmt.Sprintf("/project/src/file%d.go", i),
			Content:  content,
		})
	}

	b.Run("parallel_grep", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_, err := backend.GrepRaw(ctx, &GrepRequest{
				Pattern:  "log\\.Error",
				FileType: "go",
			})
			if err != nil {
				b.Fatalf("GrepRaw failed: %v", err)
			}
		}
	})

	b.Run("with_glob_filter", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_, err := backend.GrepRaw(ctx, &GrepRequest{
				Pattern: "Error",
				Glob:    "**/*.go",
			})
			if err != nil {
				b.Fatalf("GrepRaw failed: %v", err)
			}
		}
	})

	b.Run("case_insensitive", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_, err := backend.GrepRaw(ctx, &GrepRequest{
				Pattern:         "ERROR",
				CaseInsensitive: true,
			})
			if err != nil {
				b.Fatalf("GrepRaw failed: %v", err)
			}
		}
	})
}

// TestInMemoryBackend_GrepRaw_ComplexScenarios combines GrepRaw's Path,
// FileType and Glob filters with regex and case-insensitive patterns over a
// small /project tree.
func TestInMemoryBackend_GrepRaw_ComplexScenarios(t *testing.T) {
	backend := NewInMemoryBackend()
	ctx := context.Background()

	backend.Write(ctx, &WriteRequest{
		FilePath: "/project/src/main.go",
		Content:  "package main\nimport \"log\"\nfunc main() {\n\tlog.Error(\"error\")\n}",
	})
	backend.Write(ctx, &WriteRequest{
		FilePath: "/project/src/utils/helper.go",
		Content:  "package utils\nfunc Helper() error {\n\treturn nil\n}",
	})
	backend.Write(ctx, &WriteRequest{
		FilePath: "/project/test/main_test.go",
		Content:  "package main\nimport \"testing\"\nfunc TestMain(t *testing.T) {\n}",
	})

	t.Run("combine path and file type filters", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern:  "package",
			Path:     "/project/src",
			FileType: "go",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		for _, match := range matches {
			if !strings.HasPrefix(match.Path, "/project/src") {
				t.Errorf("Expected path to start with /project/src, got: %s", match.Path)
			}
			if !strings.HasSuffix(match.Path, ".go") {
				t.Errorf("Expected .go file, got: %s", match.Path)
			}
		}
	})

	t.Run("complex regex with case insensitive", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern:         "func\\s+\\w+",
			CaseInsensitive: true,
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		if len(matches) == 0 {
			t.Error("Expected at least 1 match for function declarations")
		}
	})

	t.Run("glob with directory structure", func(t *testing.T) {
		matches, err := backend.GrepRaw(ctx, &GrepRequest{
			Pattern: "package",
			Glob:    "*_test.go",
		})
		if err != nil {
			t.Fatalf("GrepRaw failed: %v", err)
		}
		for _, match := range
matches { if !strings.HasSuffix(match.Path, "_test.go") { t.Errorf("Expected test file, got: %s", match.Path) } } }) t.Run("glob with recursive pattern", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "package", Glob: "**/*.go", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } for _, match := range matches { if !strings.HasSuffix(match.Path, ".go") { t.Errorf("Expected .go file, got: %s", match.Path) } } if len(matches) == 0 { t.Error("Expected at least 1 match for **/*.go pattern") } }) t.Run("glob with path prefix", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "package", Glob: "src/**/*.go", Path: "/project", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } for _, match := range matches { if !strings.HasPrefix(match.Path, "/project/src") { t.Errorf("Expected path to start with /project/src, got: %s", match.Path) } if !strings.HasSuffix(match.Path, ".go") { t.Errorf("Expected .go file, got: %s", match.Path) } } }) t.Run("glob simple filename pattern", func(t *testing.T) { matches, err := backend.GrepRaw(ctx, &GrepRequest{ Pattern: "package", Glob: "main.go", }) if err != nil { t.Fatalf("GrepRaw failed: %v", err) } for _, match := range matches { if filepath.Base(match.Path) != "main.go" { t.Errorf("Expected filename 'main.go', got: %s", match.Path) } } }) } func TestInMemoryBackend_Read_Scenarios(t *testing.T) { ctx := context.Background() t.Run("empty file returns empty content", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/empty.txt", Content: ""}) content, err := backend.Read(ctx, &ReadRequest{FilePath: "/empty.txt"}) if err != nil { t.Fatalf("unexpected error: %v", err) } if content.Content != "" { t.Errorf("expected empty content, got %q", content.Content) } }) t.Run("single-line file without newline", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/single.txt", Content: 
"hello"}) content, err := backend.Read(ctx, &ReadRequest{FilePath: "/single.txt"}) if err != nil { t.Fatalf("unexpected error: %v", err) } if content.Content != "hello" { t.Errorf("expected %q, got %q", "hello", content.Content) } }) t.Run("offset 0 and offset 1 both start from first line", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/f.txt", Content: "a\nb\nc"}) c0, _ := backend.Read(ctx, &ReadRequest{FilePath: "/f.txt", Offset: 0, Limit: 1}) c1, _ := backend.Read(ctx, &ReadRequest{FilePath: "/f.txt", Offset: 1, Limit: 1}) if c0.Content != c1.Content { t.Errorf("Offset=0 (%q) and Offset=1 (%q) should return the same first line", c0.Content, c1.Content) } if c0.Content != "a" { t.Errorf("expected first line %q, got %q", "a", c0.Content) } }) t.Run("file with trailing newline preserves trailing empty line", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/trail.txt", Content: "line1\nline2\n"}) content, err := backend.Read(ctx, &ReadRequest{FilePath: "/trail.txt"}) if err != nil { t.Fatalf("unexpected error: %v", err) } if content.Content != "line1\nline2\n" { t.Errorf("expected %q, got %q", "line1\nline2\n", content.Content) } lines := strings.Split(content.Content, "\n") if len(lines) != 3 { // ["line1", "line2", ""] t.Errorf("expected 3 elements from split, got %d", len(lines)) } }) t.Run("offset exactly at last line", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/f.txt", Content: "a\nb\nc"}) // Offset=3 (1-based) → last line "c" content, err := backend.Read(ctx, &ReadRequest{FilePath: "/f.txt", Offset: 3, Limit: 10}) if err != nil { t.Fatalf("unexpected error: %v", err) } if content.Content != "c" { t.Errorf("expected %q, got %q", "c", content.Content) } }) t.Run("offset one beyond last line returns empty", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/f.txt", 
Content: "a\nb\nc"}) content, err := backend.Read(ctx, &ReadRequest{FilePath: "/f.txt", Offset: 4, Limit: 10}) if err != nil { t.Fatalf("unexpected error: %v", err) } if content.Content != "" { t.Errorf("expected empty content, got %q", content.Content) } }) t.Run("limit=1 reads exactly one line", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/f.txt", Content: "a\nb\nc"}) for i, expected := range []string{"a", "b", "c"} { content, err := backend.Read(ctx, &ReadRequest{FilePath: "/f.txt", Offset: i + 1, Limit: 1}) if err != nil { t.Fatalf("line %d: unexpected error: %v", i+1, err) } if content.Content != expected { t.Errorf("line %d: expected %q, got %q", i+1, expected, content.Content) } } }) t.Run("sliding window reads consecutive ranges correctly", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/f.txt", Content: "l1\nl2\nl3\nl4\nl5"}) tests := []struct { offset int limit int expected string }{ {1, 2, "l1\nl2"}, {2, 2, "l2\nl3"}, {3, 2, "l3\nl4"}, {4, 2, "l4\nl5"}, {5, 2, "l5"}, } for _, tt := range tests { content, err := backend.Read(ctx, &ReadRequest{FilePath: "/f.txt", Offset: tt.offset, Limit: tt.limit}) if err != nil { t.Fatalf("offset=%d limit=%d: unexpected error: %v", tt.offset, tt.limit, err) } if content.Content != tt.expected { t.Errorf("offset=%d limit=%d: expected %q, got %q", tt.offset, tt.limit, tt.expected, content.Content) } } }) t.Run("file with only newlines", func(t *testing.T) { backend := NewInMemoryBackend() backend.Write(ctx, &WriteRequest{FilePath: "/newlines.txt", Content: "\n\n\n"}) content, err := backend.Read(ctx, &ReadRequest{FilePath: "/newlines.txt", Offset: 2, Limit: 1}) if err != nil { t.Fatalf("unexpected error: %v", err) } // Line 2 is an empty string between two newlines if content.Content != "" { t.Errorf("expected empty line content, got %q", content.Content) } }) } ================================================ FILE: 
adk/flow.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "errors" "fmt" "log" "runtime/debug" "strings" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/compose" icb "github.com/cloudwego/eino/internal/callbacks" "github.com/cloudwego/eino/internal/safe" "github.com/cloudwego/eino/schema" ) type HistoryEntry struct { IsUserInput bool AgentName string Message Message } type HistoryRewriter func(ctx context.Context, entries []*HistoryEntry) ([]Message, error) type flowAgent struct { Agent subAgents []*flowAgent parentAgent *flowAgent disallowTransferToParent bool historyRewriter HistoryRewriter checkPointStore compose.CheckPointStore } func (a *flowAgent) deepCopy() *flowAgent { ret := &flowAgent{ Agent: a.Agent, subAgents: make([]*flowAgent, 0, len(a.subAgents)), parentAgent: a.parentAgent, disallowTransferToParent: a.disallowTransferToParent, historyRewriter: a.historyRewriter, checkPointStore: a.checkPointStore, } for _, sa := range a.subAgents { ret.subAgents = append(ret.subAgents, sa.deepCopy()) } return ret } // SetSubAgents sets sub-agents for the given agent and returns the updated agent. 
func SetSubAgents(ctx context.Context, agent Agent, subAgents []Agent) (ResumableAgent, error) { return setSubAgents(ctx, agent, subAgents) } type AgentOption func(options *flowAgent) // WithDisallowTransferToParent prevents a sub-agent from transferring to its parent. func WithDisallowTransferToParent() AgentOption { return func(fa *flowAgent) { fa.disallowTransferToParent = true } } // WithHistoryRewriter sets a rewriter to transform conversation history. func WithHistoryRewriter(h HistoryRewriter) AgentOption { return func(fa *flowAgent) { fa.historyRewriter = h } } func toFlowAgent(ctx context.Context, agent Agent, opts ...AgentOption) *flowAgent { var fa *flowAgent var ok bool if fa, ok = agent.(*flowAgent); !ok { fa = &flowAgent{Agent: agent} } else { fa = fa.deepCopy() } for _, opt := range opts { opt(fa) } if fa.historyRewriter == nil { fa.historyRewriter = buildDefaultHistoryRewriter(agent.Name(ctx)) } return fa } // AgentWithOptions wraps an agent with flow-specific options and returns it. func AgentWithOptions(ctx context.Context, agent Agent, opts ...AgentOption) Agent { return toFlowAgent(ctx, agent, opts...) 
} func setSubAgents(ctx context.Context, agent Agent, subAgents []Agent) (*flowAgent, error) { fa := toFlowAgent(ctx, agent) if len(fa.subAgents) > 0 { return nil, errors.New("agent's sub-agents has already been set") } if onAgent, ok_ := fa.Agent.(OnSubAgents); ok_ { err := onAgent.OnSetSubAgents(ctx, subAgents) if err != nil { return nil, err } } for _, s := range subAgents { fsa := toFlowAgent(ctx, s) if fsa.parentAgent != nil { return nil, errors.New("agent has already been set as a sub-agent of another agent") } fsa.parentAgent = fa if onAgent, ok__ := fsa.Agent.(OnSubAgents); ok__ { err := onAgent.OnSetAsSubAgent(ctx, agent) if err != nil { return nil, err } if fsa.disallowTransferToParent { err = onAgent.OnDisallowTransferToParent(ctx) if err != nil { return nil, err } } } fa.subAgents = append(fa.subAgents, fsa) } return fa, nil } func (a *flowAgent) getAgent(ctx context.Context, name string) *flowAgent { for _, subAgent := range a.subAgents { if subAgent.Name(ctx) == name { return subAgent } } if a.parentAgent != nil && a.parentAgent.Name(ctx) == name { return a.parentAgent } return nil } func rewriteMessage(msg Message, agentName string) Message { var sb strings.Builder sb.WriteString("For context:") if msg.Role == schema.Assistant { if msg.Content != "" { sb.WriteString(fmt.Sprintf(" [%s] said: %s.", agentName, msg.Content)) } if len(msg.ToolCalls) > 0 { for i := range msg.ToolCalls { f := msg.ToolCalls[i].Function sb.WriteString(fmt.Sprintf(" [%s] called tool: `%s` with arguments: %s.", agentName, f.Name, f.Arguments)) } } } else if msg.Role == schema.Tool && msg.Content != "" { sb.WriteString(fmt.Sprintf(" [%s] `%s` tool returned result: %s.", agentName, msg.ToolName, msg.Content)) } rewritten := schema.UserMessage(sb.String()) if msg.MultiContent != nil { rewritten.MultiContent = append([]schema.ChatMessagePart{}, msg.MultiContent...) 
} if msg.UserInputMultiContent != nil { rewritten.UserInputMultiContent = append([]schema.MessageInputPart{}, msg.UserInputMultiContent...) } // Convert AssistantGenMultiContent to UserInputMultiContent, since the role changes to User. // Reasoning parts have no user input equivalent and are dropped. for _, part := range msg.AssistantGenMultiContent { switch part.Type { case schema.ChatMessagePartTypeText: rewritten.UserInputMultiContent = append(rewritten.UserInputMultiContent, schema.MessageInputPart{ Type: part.Type, Text: part.Text, Extra: part.Extra, }) case schema.ChatMessagePartTypeImageURL: if part.Image != nil { rewritten.UserInputMultiContent = append(rewritten.UserInputMultiContent, schema.MessageInputPart{ Type: part.Type, Image: &schema.MessageInputImage{MessagePartCommon: part.Image.MessagePartCommon}, Extra: part.Extra, }) } case schema.ChatMessagePartTypeAudioURL: if part.Audio != nil { rewritten.UserInputMultiContent = append(rewritten.UserInputMultiContent, schema.MessageInputPart{ Type: part.Type, Audio: &schema.MessageInputAudio{MessagePartCommon: part.Audio.MessagePartCommon}, Extra: part.Extra, }) } case schema.ChatMessagePartTypeVideoURL: if part.Video != nil { rewritten.UserInputMultiContent = append(rewritten.UserInputMultiContent, schema.MessageInputPart{ Type: part.Type, Video: &schema.MessageInputVideo{MessagePartCommon: part.Video.MessagePartCommon}, Extra: part.Extra, }) } } } return rewritten } func genMsg(entry *HistoryEntry, agentName string) (Message, error) { msg := entry.Message if entry.AgentName != agentName { msg = rewriteMessage(msg, entry.AgentName) } return msg, nil } func (ai *AgentInput) deepCopy() *AgentInput { copied := &AgentInput{ Messages: make([]Message, len(ai.Messages)), EnableStreaming: ai.EnableStreaming, } copy(copied.Messages, ai.Messages) return copied } func (a *flowAgent) genAgentInput(ctx context.Context, runCtx *runContext, skipTransferMessages bool) (*AgentInput, error) { input := 
runCtx.RootInput.deepCopy() events := runCtx.Session.getEvents() historyEntries := make([]*HistoryEntry, 0) for _, m := range input.Messages { historyEntries = append(historyEntries, &HistoryEntry{ IsUserInput: true, Message: m, }) } for _, event := range events { if skipTransferMessages && event.Action != nil && event.Action.TransferToAgent != nil { // If skipTransferMessages is true and the event contain transfer action, the message in this event won't be appended to history entries. if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.Role == schema.Tool && len(historyEntries) > 0 { // If the skipped message's role is Tool, remove the previous history entry as it's also a transfer message(from ChatModelAgent and GenTransferMessages). historyEntries = historyEntries[:len(historyEntries)-1] } continue } msg, err := getMessageFromWrappedEvent(event) if err != nil { var retryErr *WillRetryError if errors.As(err, &retryErr) { log.Printf("failed to get message from event, but will retry: %v", err) continue } return nil, err } if msg == nil { continue } historyEntries = append(historyEntries, &HistoryEntry{ AgentName: event.AgentName, Message: msg, }) } messages, err := a.historyRewriter(ctx, historyEntries) if err != nil { return nil, err } input.Messages = messages return input, nil } func buildDefaultHistoryRewriter(agentName string) HistoryRewriter { return func(ctx context.Context, entries []*HistoryEntry) ([]Message, error) { messages := make([]Message, 0, len(entries)) var err error for _, entry := range entries { msg := entry.Message if !entry.IsUserInput { msg, err = genMsg(entry, agentName) if err != nil { return nil, fmt.Errorf("gen agent input failed: %w", err) } } if msg != nil { messages = append(messages, msg) } } return messages, nil } } func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { agentName := a.Name(ctx) var runCtx *runContext ctx, runCtx = 
initRunCtx(ctx, agentName, input) ctx = AppendAddressSegment(ctx, AddressSegmentAgent, agentName) o := getCommonOptions(nil, opts...) processedInput, err := a.genAgentInput(ctx, runCtx, o.skipTransferMessages) if err != nil { cbInput := &AgentCallbackInput{Input: input} ctx = callbacks.OnStart(ctx, cbInput) return wrapIterWithOnEnd(ctx, genErrorIter(err)) } ctxForSubAgents := ctx agentType := getAgentType(a.Agent) ctx = initAgentCallbacks(ctx, agentName, agentType, filterOptions(agentName, opts)...) cbInput := &AgentCallbackInput{Input: processedInput} ctx = callbacks.OnStart(ctx, cbInput) input = processedInput if wf, ok := a.Agent.(*workflowAgent); ok { return wrapIterWithOnEnd(ctx, wf.Run(ctx, input, filterCallbackHandlersForNestedAgents(agentName, opts)...)) } aIter := a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go a.run(ctx, ctxForSubAgents, runCtx, aIter, generator, opts...) return iterator } func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { agentName := a.Name(ctx) ctx, info = buildResumeInfo(ctx, agentName, info) ctxForSubAgents := ctx agentType := getAgentType(a.Agent) ctx = initAgentCallbacks(ctx, agentName, agentType, filterOptions(agentName, opts)...) cbInput := &AgentCallbackInput{ResumeInfo: info} ctx = callbacks.OnStart(ctx, cbInput) if info.WasInterrupted { ra, ok := a.Agent.(ResumableAgent) if !ok { return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ "but is not a ResumableAgent", agentName))) } iterator, generator := NewAsyncIteratorPair[*AgentEvent]() if _, ok := ra.(*workflowAgent); ok { filteredOpts := filterCallbackHandlersForNestedAgents(agentName, opts) aIter := ra.Resume(ctx, info, filteredOpts...) return wrapIterWithOnEnd(ctx, aIter) } aIter := ra.Resume(ctx, info, opts...) 
go a.run(ctx, ctxForSubAgents, getRunCtx(ctxForSubAgents), aIter, generator, opts...) return iterator } nextAgentName, err := getNextResumeAgent(ctx, info) if err != nil { return wrapIterWithOnEnd(ctx, genErrorIter(err)) } subAgent := a.getAgent(ctxForSubAgents, nextAgentName) if subAgent == nil { // the inner agent wrapped by flowAgent may be ANY agent, including flowAgent, // AgentWithDeterministicTransferTo, or any other custom agent user defined, // or any combinations of the above in any order, // that ultimately wraps the flowAgent with sub-agents // We need to go through these wrappers to reach the flowAgent with sub-agents. if len(a.subAgents) == 0 { if ra, ok := a.Agent.(ResumableAgent); ok { // Use ctx (callback-enriched) instead of ctxForSubAgents here. // This is the inner agent that flowAgent wraps (e.g., supervisorContainer), // not a sub-agent. The callback context from OnStart should be propagated // to ensure unified tracing for container patterns. return wrapIterWithOnEnd(ctx, ra.Resume(ctx, info, opts...)) } } return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' not found from flowAgent '%s'", nextAgentName, agentName))) } return wrapIterWithOnEnd(ctx, subAgent.Resume(ctxForSubAgents, info, opts...)) } type DeterministicTransferConfig struct { Agent Agent ToAgentNames []string } func (a *flowAgent) run( ctx context.Context, ctxForSubAgents context.Context, runCtx *runContext, aIter *AsyncIterator[*AgentEvent], generator *AsyncGenerator[*AgentEvent], opts ...AgentRunOption) { cbIter, cbGen := NewAsyncIteratorPair[*AgentEvent]() cbOutput := &AgentCallbackOutput{Events: cbIter} icb.On(ctx, cbOutput, icb.BuildOnEndHandleWithCopy(copyAgentCallbackOutput), callbacks.TimingOnEnd, false) defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) generator.Send(&AgentEvent{Err: e}) } cbGen.Close() generator.Close() }() var lastAction *AgentAction for { event, ok := 
aIter.Next() if !ok { break } // RunPath ownership: the eino framework sets RunPath exactly once. // If event.RunPath is already set (e.g., by agentTool), we don't modify it. // If event.RunPath is nil/empty, we set it to the current runCtx.RunPath. // This ensures RunPath is set exactly once and not duplicated. if len(event.RunPath) == 0 { event.AgentName = a.Name(ctx) event.RunPath = runCtx.RunPath } // Recording policy: exact RunPath match (non-interrupt) indicates events belonging to this agent execution. // This prevents parent recording of child/tool-internal emissions. if (event.Action == nil || event.Action.Interrupted == nil) && exactRunPathMatch(runCtx.RunPath, event.RunPath) { // copy the event so that the copied event's stream is exclusive for any potential consumer // copy before adding to session because once added to session it's stream could be consumed by genAgentInput at any time // interrupt action are not added to session, because ALL information contained in it // is either presented to end-user, or made available to agents through other means copied := copyAgentEvent(event) setAutomaticClose(copied) setAutomaticClose(event) runCtx.Session.addEvent(copied) } // Action gating uses exact run-path match as well: // only actions originating from this agent execution (not child/tool runs) // should influence parent control flow (exit/transfer/interrupt). 
if exactRunPathMatch(runCtx.RunPath, event.RunPath) { lastAction = event.Action } copied := copyAgentEvent(event) setAutomaticClose(copied) setAutomaticClose(event) cbGen.Send(copied) generator.Send(event) } var destName string if lastAction != nil { if lastAction.Interrupted != nil { return } if lastAction.Exit { return } if lastAction.TransferToAgent != nil { destName = lastAction.TransferToAgent.DestAgentName } } // handle transferring to another agent if destName != "" { agentToRun := a.getAgent(ctxForSubAgents, destName) if agentToRun == nil { e := fmt.Errorf("transfer failed: agent '%s' not found when transferring from '%s'", destName, a.Name(ctxForSubAgents)) generator.Send(&AgentEvent{Err: e}) return } subAIter := agentToRun.Run(ctxForSubAgents, nil /*subagents get input from runCtx*/, opts...) for { subEvent, ok_ := subAIter.Next() if !ok_ { break } setAutomaticClose(subEvent) generator.Send(subEvent) } } } func exactRunPathMatch(aPath, bPath []RunStep) bool { if len(aPath) != len(bPath) { return false } for i := range aPath { if !aPath[i].Equals(bPath[i]) { return false } } return true } func wrapIterWithOnEnd(ctx context.Context, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] { cbIter, cbGen := NewAsyncIteratorPair[*AgentEvent]() cbOutput := &AgentCallbackOutput{Events: cbIter} icb.On(ctx, cbOutput, icb.BuildOnEndHandleWithCopy(copyAgentCallbackOutput), callbacks.TimingOnEnd, false) outIter, outGen := NewAsyncIteratorPair[*AgentEvent]() go func() { defer func() { cbGen.Close() outGen.Close() }() for { event, ok := iter.Next() if !ok { break } copied := copyAgentEvent(event) cbGen.Send(copied) outGen.Send(event) } }() return outIter } ================================================ FILE: adk/flow_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "sync" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/callbacks" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) func strPtr(s string) *string { return &s } func TestRewriteMessage(t *testing.T) { imageCommon := schema.MessagePartCommon{URL: strPtr("http://img.example.com")} audioCommon := schema.MessagePartCommon{URL: strPtr("http://audio.example.com")} videoCommon := schema.MessagePartCommon{URL: strPtr("http://video.example.com")} msg := &schema.Message{ Role: schema.Assistant, Content: "hello", MultiContent: []schema.ChatMessagePart{ {Type: schema.ChatMessagePartTypeText, Text: "legacy"}, }, UserInputMultiContent: []schema.MessageInputPart{ {Type: schema.ChatMessagePartTypeText, Text: "pre-existing"}, }, AssistantGenMultiContent: []schema.MessageOutputPart{ {Type: schema.ChatMessagePartTypeText, Text: "gen-text", Extra: map[string]any{"k": "v"}}, {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: imageCommon}}, {Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageOutputAudio{MessagePartCommon: audioCommon}}, {Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{MessagePartCommon: videoCommon}}, {Type: schema.ChatMessagePartTypeReasoning, Reasoning: &schema.MessageOutputReasoning{Text: "secret thoughts"}}, }, } rewritten := rewriteMessage(msg, "OtherAgent") assert.Equal(t, schema.User, rewritten.Role) // MultiContent: copied, not shared 
assert.Equal(t, msg.MultiContent, rewritten.MultiContent) rewritten.MultiContent[0].Text = "mutated" assert.Equal(t, "legacy", msg.MultiContent[0].Text) // UserInputMultiContent: pre-existing entry copied, AssistantGenMultiContent appended (reasoning dropped) assert.Len(t, rewritten.UserInputMultiContent, 5) // 1 pre-existing + 4 converted (text/image/audio/video) // pre-existing entry is not shared rewritten.UserInputMultiContent[0].Text = "mutated" assert.Equal(t, "pre-existing", msg.UserInputMultiContent[0].Text) // text conversion assert.Equal(t, schema.ChatMessagePartTypeText, rewritten.UserInputMultiContent[1].Type) assert.Equal(t, "gen-text", rewritten.UserInputMultiContent[1].Text) assert.Equal(t, map[string]any{"k": "v"}, rewritten.UserInputMultiContent[1].Extra) // image conversion assert.Equal(t, schema.ChatMessagePartTypeImageURL, rewritten.UserInputMultiContent[2].Type) assert.Equal(t, imageCommon, rewritten.UserInputMultiContent[2].Image.MessagePartCommon) // audio conversion assert.Equal(t, schema.ChatMessagePartTypeAudioURL, rewritten.UserInputMultiContent[3].Type) assert.Equal(t, audioCommon, rewritten.UserInputMultiContent[3].Audio.MessagePartCommon) // video conversion assert.Equal(t, schema.ChatMessagePartTypeVideoURL, rewritten.UserInputMultiContent[4].Type) assert.Equal(t, videoCommon, rewritten.UserInputMultiContent[4].Video.MessagePartCommon) // reasoning is dropped; AssistantGenMultiContent is not set on rewritten message assert.Empty(t, rewritten.AssistantGenMultiContent) } // TestTransferToAgent tests the TransferToAgent functionality func TestTransferToAgent(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create mock models for parent and child agents parentModel := mockModel.NewMockToolCallingChatModel(ctrl) childModel := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the parent model // First call: parent model generates a message 
with TransferToAgent tool call parentModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("I'll transfer this to the child agent", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: TransferToAgentToolName, Arguments: `{"agent_name": "ChildAgent"}`, }, }, }), nil). Times(1) // Set up expectations for the child model // Second call: child model generates a response childModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("Hello from child agent", nil), nil). Times(1) // Both models should implement WithTools parentModel.EXPECT().WithTools(gomock.Any()).Return(parentModel, nil).AnyTimes() childModel.EXPECT().WithTools(gomock.Any()).Return(childModel, nil).AnyTimes() // Create parent agent parentAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "ParentAgent", Description: "Parent agent that will transfer to child", Instruction: "You are a parent agent.", Model: parentModel, }) assert.NoError(t, err) assert.NotNil(t, parentAgent) // Create child agent childAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "ChildAgent", Description: "Child agent that handles specific tasks", Instruction: "You are a child agent.", Model: childModel, }) assert.NoError(t, err) assert.NotNil(t, childAgent) // Set up parent-child relationship flowAgent, err := SetSubAgents(ctx, parentAgent, []Agent{childAgent}) assert.NoError(t, err) assert.NotNil(t, flowAgent) assert.NotNil(t, parentAgent.subAgents) assert.NotNil(t, childAgent.parentAgent) // Run the parent agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Please transfer this to the child agent"), }, } ctx, _ = initRunCtx(ctx, flowAgent.Name(ctx), input) iterator := flowAgent.Run(ctx, input) assert.NotNil(t, iterator) // First event: parent model output with tool call event1, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event1) assert.Nil(t, event1.Err) 
assert.NotNil(t, event1.Output) assert.NotNil(t, event1.Output.MessageOutput) assert.Equal(t, schema.Assistant, event1.Output.MessageOutput.Role) // Second event: tool output (TransferToAgent) event2, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event2) assert.Nil(t, event2.Err) assert.NotNil(t, event2.Output) assert.NotNil(t, event2.Output.MessageOutput) assert.Equal(t, schema.Tool, event2.Output.MessageOutput.Role) // Verify the action is TransferToAgent assert.NotNil(t, event2.Action) assert.NotNil(t, event2.Action.TransferToAgent) assert.Equal(t, "ChildAgent", event2.Action.TransferToAgent.DestAgentName) // Third event: child model output event3, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event3) assert.Nil(t, event3.Err) assert.NotNil(t, event3.Output) assert.NotNil(t, event3.Output.MessageOutput) assert.Equal(t, schema.Assistant, event3.Output.MessageOutput.Role) // Verify the message content from child agent msg := event3.Output.MessageOutput.Message assert.NotNil(t, msg) assert.Equal(t, "Hello from child agent", msg.Content) // No more events _, ok = iterator.Next() assert.False(t, ok) } func TestTransferToAgentWithDesignatedCallback(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() parentModel := mockModel.NewMockToolCallingChatModel(ctrl) childModel := mockModel.NewMockToolCallingChatModel(ctrl) parentModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("I'll transfer this to the child agent", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: TransferToAgentToolName, Arguments: `{"agent_name": "ChildAgent"}`, }, }, }), nil). Times(1) childModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("Hello from child agent", nil), nil). 
Times(1) parentModel.EXPECT().WithTools(gomock.Any()).Return(parentModel, nil).AnyTimes() childModel.EXPECT().WithTools(gomock.Any()).Return(childModel, nil).AnyTimes() parentAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "ParentAgent", Description: "Parent agent that will transfer to child", Instruction: "You are a parent agent.", Model: parentModel, }) assert.NoError(t, err) childAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "ChildAgent", Description: "Child agent that handles specific tasks", Instruction: "You are a child agent.", Model: childModel, }) assert.NoError(t, err) flowAgent, err := SetSubAgents(ctx, parentAgent, []Agent{childAgent}) assert.NoError(t, err) var childCallbackCount int var mu sync.Mutex handler := callbacks.NewHandlerBuilder().OnStartFn( func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Component == ComponentOfAgent && info.Name == "ChildAgent" { mu.Lock() childCallbackCount++ mu.Unlock() } return ctx }).Build() input := &AgentInput{ Messages: []Message{ schema.UserMessage("Please transfer this to the child agent"), }, } ctx, _ = initRunCtx(ctx, flowAgent.Name(ctx), input) iterator := flowAgent.Run(ctx, input, WithCallbacks(handler).DesignateAgent("ChildAgent")) for { _, ok := iterator.Next() if !ok { break } } assert.Equal(t, 1, childCallbackCount, "designated callback for ChildAgent should fire exactly once during transfer") } ================================================ FILE: adk/handler.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package adk

import (
	"context"
	"fmt"

	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/schema"
)

// InvokableToolCallEndpoint is the function signature for invoking a tool synchronously.
// It receives the tool arguments as a JSON string and returns the tool result as a string.
// Middleware authors implement wrappers around this endpoint to add custom behavior
// (see ChatModelAgentMiddleware.WrapInvokableToolCall).
type InvokableToolCallEndpoint func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error)

// StreamableToolCallEndpoint is the function signature for invoking a tool with streaming output.
// It receives the tool arguments as a JSON string and returns a stream of string values.
// Middleware authors implement wrappers around this endpoint to add custom behavior
// (see ChatModelAgentMiddleware.WrapStreamableToolCall).
type StreamableToolCallEndpoint func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error)

// EnhancedInvokableToolCallEndpoint is the function signature for invoking an enhanced tool synchronously.
// Unlike InvokableToolCallEndpoint, it takes a structured *schema.ToolArgument instead of a JSON
// string and returns a *schema.ToolResult instead of a plain string.
// Middleware authors wrap it via ChatModelAgentMiddleware.WrapEnhancedInvokableToolCall.
type EnhancedInvokableToolCallEndpoint func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error)

// EnhancedStreamableToolCallEndpoint is the function signature for invoking an enhanced tool with
// streaming output. It takes a structured *schema.ToolArgument and returns a stream of
// *schema.ToolResult values.
// Middleware authors wrap it via ChatModelAgentMiddleware.WrapEnhancedStreamableToolCall.
type EnhancedStreamableToolCallEndpoint func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error)

// ToolContext provides metadata about the tool being wrapped.
// It is passed to every Wrap*ToolCall method of ChatModelAgentMiddleware.
type ToolContext struct {
	// Name is the name of the tool being wrapped.
	Name string
	// CallID is the unique identifier of this specific tool call.
	CallID string
}

// ModelContext contains context information passed to WrapModel.
// The same type is also handed to BeforeModelRewriteState and AfterModelRewriteState.
type ModelContext struct {
	// Tools contains the current tool list configured for the agent.
	// This is populated at request time with the tools that will be sent to the model.
	Tools []*schema.ToolInfo

	// ModelRetryConfig contains the retry configuration for the model.
	// This is populated at request time from the agent's ModelRetryConfig.
	// Used by EventSenderModelWrapper to wrap stream errors appropriately.
	ModelRetryConfig *ModelRetryConfig
}

// ChatModelAgentContext contains runtime information passed to handlers before each ChatModelAgent run.
// Handlers can modify Instruction, Tools, and ReturnDirectly to customize agent behavior.
//
// This type is specific to ChatModelAgent. Other agent types may define their own context types.
type ChatModelAgentContext struct {
	// Instruction is the current instruction for the Agent execution.
	// It includes the instruction configured for the agent, additional instructions appended by the
	// framework and AgentMiddleware, and modifications applied by previous BeforeAgent handlers.
	// The final instruction, after all BeforeAgent handlers have run, is passed to GenModelInput,
	// to be (optionally) formatted with SessionValues and converted to a system message.
	Instruction string

	// Tools are the raw tools (without any wrapper or tool middleware) currently configured for the
	// Agent execution. They include tools passed in ChatModelAgentConfig.ToolsConfig, implicit tools
	// added by the framework such as transfer / exit tools, and tools already added by earlier
	// middlewares and handlers.
	Tools []tool.BaseTool

	// ReturnDirectly is the set of tool names currently configured to cause the Agent to return directly.
	// It starts from the return-directly map configured for the agent, plus any modifications
	// made by previous BeforeAgent handlers.
	ReturnDirectly map[string]bool
}

// ChatModelAgentMiddleware defines the interface for customizing ChatModelAgent behavior.
//
// IMPORTANT: This interface is specifically designed for ChatModelAgent and agents built
// on top of it (e.g., DeepAgent).
//
// Why ChatModelAgentMiddleware instead of AgentMiddleware?
//
// AgentMiddleware is a struct type, which has inherent limitations:
//   - Struct types are closed: users cannot add new methods to extend functionality
//   - The framework only recognizes AgentMiddleware's fixed fields, so even if users
//     embed AgentMiddleware in a custom struct and add methods, the framework cannot
//     call those methods (config.Middlewares is []AgentMiddleware, not a user type)
//   - Callbacks in AgentMiddleware only return error, cannot return modified context
//
// ChatModelAgentMiddleware is an interface type, which is open for extension:
//   - Users can implement custom handlers with arbitrary internal state and methods
//   - Hook methods return (context.Context, ..., error) for direct context propagation
//   - Wrapper methods (WrapInvokableToolCall, WrapStreamableToolCall, their Enhanced
//     variants, and WrapModel) enable context propagation through the wrapped endpoint
//     chain: wrappers can pass modified context to the next wrapper
//   - Configuration is centralized in struct fields rather than scattered in closures
//
// ChatModelAgentMiddleware vs AgentMiddleware:
//   - Use AgentMiddleware for simple, static additions (extra instruction/tools)
//   - Use ChatModelAgentMiddleware for dynamic behavior, context modification, or call wrapping
//   - AgentMiddleware is kept for backward compatibility with existing users
//   - Both can be used together; see AgentMiddleware documentation for execution order
//
// Use *BaseChatModelAgentMiddleware as an embedded struct to provide default no-op
// implementations for all methods.
type ChatModelAgentMiddleware interface {
	// BeforeAgent is called before each agent run, allowing modification of
	// the agent's instruction, tools, and return-directly configuration
	// (the Instruction, Tools, and ReturnDirectly fields of ChatModelAgentContext).
	BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error)

	// BeforeModelRewriteState is called before each model invocation.
	// The returned state is persisted to the agent's internal state and passed to the model.
	// The returned context is propagated to the model call and subsequent handlers.
	//
	// The ChatModelAgentState struct provides access to:
	//   - Messages: the conversation history
	//
	// The ModelContext struct provides read-only access to:
	//   - Tools: the current tool list that will be sent to the model
	//   - ModelRetryConfig: the agent's model retry configuration
	BeforeModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error)

	// AfterModelRewriteState is called after each model invocation.
	// The input state includes the model's response as the last message.
	// The returned state is persisted to the agent's internal state.
	//
	// The ChatModelAgentState struct provides access to:
	//   - Messages: the conversation history including the model's response
	//
	// The ModelContext struct provides read-only access to:
	//   - Tools: the current tool list that was sent to the model
	//   - ModelRetryConfig: the agent's model retry configuration
	AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error)

	// WrapInvokableToolCall wraps a tool's synchronous execution with custom behavior.
	// Return the input endpoint unchanged and nil error if no wrapping is needed.
	//
	// This method is only called for tools that implement InvokableTool.
	// If a tool only implements StreamableTool, this method will not be called for that tool.
	//
	// This method is called at request time when the tool is about to be executed.
	// The tCtx parameter provides metadata about the tool:
	//   - Name: The name of the tool being wrapped
	//   - CallID: The unique identifier for this specific tool call
	//
	// When several handlers wrap the same tool, a handler that appears later in the
	// handler list receives the endpoint already wrapped by earlier ones, so its wrapper
	// runs first on the way in and last on the way out.
	WrapInvokableToolCall(ctx context.Context, endpoint InvokableToolCallEndpoint, tCtx *ToolContext) (InvokableToolCallEndpoint, error)

	// WrapStreamableToolCall wraps a tool's streaming execution with custom behavior.
	// Return the input endpoint unchanged and nil error if no wrapping is needed.
	//
	// This method is only called for tools that implement StreamableTool.
	// If a tool only implements InvokableTool, this method will not be called for that tool.
	//
	// This method is called at request time when the tool is about to be executed.
	// The tCtx parameter provides metadata about the tool:
	//   - Name: The name of the tool being wrapped
	//   - CallID: The unique identifier for this specific tool call
	WrapStreamableToolCall(ctx context.Context, endpoint StreamableToolCallEndpoint, tCtx *ToolContext) (StreamableToolCallEndpoint, error)

	// WrapEnhancedInvokableToolCall wraps an enhanced tool's synchronous execution with custom behavior.
	// Return the input endpoint unchanged and nil error if no wrapping is needed.
	//
	// This method is only called for tools that implement EnhancedInvokableTool.
	// If a tool only implements EnhancedStreamableTool, this method will not be called for that tool.
	//
	// This method is called at request time when the tool is about to be executed.
	// The tCtx parameter provides metadata about the tool:
	//   - Name: The name of the tool being wrapped
	//   - CallID: The unique identifier for this specific tool call
	WrapEnhancedInvokableToolCall(ctx context.Context, endpoint EnhancedInvokableToolCallEndpoint, tCtx *ToolContext) (EnhancedInvokableToolCallEndpoint, error)

	// WrapEnhancedStreamableToolCall wraps an enhanced tool's streaming execution with custom behavior.
	// Return the input endpoint unchanged and nil error if no wrapping is needed.
	//
	// This method is only called for tools that implement EnhancedStreamableTool.
	// If a tool only implements EnhancedInvokableTool, this method will not be called for that tool.
	//
	// This method is called at request time when the tool is about to be executed.
	// The tCtx parameter provides metadata about the tool:
	//   - Name: The name of the tool being wrapped
	//   - CallID: The unique identifier for this specific tool call
	WrapEnhancedStreamableToolCall(ctx context.Context, endpoint EnhancedStreamableToolCallEndpoint, tCtx *ToolContext) (EnhancedStreamableToolCallEndpoint, error)

	// WrapModel wraps a chat model with custom behavior.
	// Return the input model unchanged and nil error if no wrapping is needed.
	//
	// This method is called at request time when the model is about to be invoked.
	// Note: The parameter is BaseChatModel (not ToolCallingChatModel) because wrappers
	// only need to intercept Generate/Stream calls. Tool binding (WithTools) is handled
	// separately by the framework and does not flow through user wrappers.
	//
	// The mc parameter contains the current model configuration:
	//   - Tools: The tool infos that will be sent to the model
	//   - ModelRetryConfig: the agent's model retry configuration
	WrapModel(ctx context.Context, m model.BaseChatModel, mc *ModelContext) (model.BaseChatModel, error)
}

// BaseChatModelAgentMiddleware provides default no-op implementations for ChatModelAgentMiddleware.
// Embed *BaseChatModelAgentMiddleware in custom handlers to only override the methods you need.
// // Example: // // type MyHandler struct { // *adk.BaseChatModelAgentMiddleware // // custom fields // } // // func (h *MyHandler) BeforeModelRewriteState(ctx context.Context, state *adk.ChatModelAgentState, mc *adk.ModelContext) (context.Context, *adk.ChatModelAgentState, error) { // // custom logic // return ctx, state, nil // } type BaseChatModelAgentMiddleware struct{} func (b *BaseChatModelAgentMiddleware) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) (InvokableToolCallEndpoint, error) { return endpoint, nil } func (b *BaseChatModelAgentMiddleware) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, _ *ToolContext) (StreamableToolCallEndpoint, error) { return endpoint, nil } func (b *BaseChatModelAgentMiddleware) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { return endpoint, nil } func (b *BaseChatModelAgentMiddleware) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { return endpoint, nil } func (b *BaseChatModelAgentMiddleware) WrapModel(_ context.Context, m model.BaseChatModel, _ *ModelContext) (model.BaseChatModel, error) { return m, nil } func (b *BaseChatModelAgentMiddleware) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { return ctx, runCtx, nil } func (b *BaseChatModelAgentMiddleware) BeforeModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) { return ctx, state, nil } func (b *BaseChatModelAgentMiddleware) AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) { return ctx, state, nil } // SetRunLocalValue sets a key-value pair that persists 
for the duration of the current agent Run() invocation. // The value is scoped to this specific execution and is not shared across different Run() calls or agent instances. // // Values stored here are compatible with interrupt/resume cycles - they will be serialized and restored // when the agent is resumed. For custom types, you must register them using schema.RegisterName[T]() // in an init() function to ensure proper serialization. // // This function can only be called from within a ChatModelAgentMiddleware during agent execution. // Returns an error if called outside of an agent execution context. func SetRunLocalValue(ctx context.Context, key string, value any) error { err := compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.Extra == nil { st.Extra = make(map[string]any) } st.Extra[key] = value return nil }) if err != nil { return fmt.Errorf("SetRunLocalValue failed: must be called within a ChatModelAgent Run() or Resume() execution context: %w", err) } return nil } // GetRunLocalValue retrieves a value that was set during the current agent Run() invocation. // The value is scoped to this specific execution and is not shared across different Run() calls or agent instances. // // Values stored via SetRunLocalValue are compatible with interrupt/resume cycles - they will be serialized // and restored when the agent is resumed. For custom types, you must register them using schema.RegisterName[T]() // in an init() function to ensure proper serialization. // // This function can only be called from within a ChatModelAgentMiddleware during agent execution. // Returns the value and true if found, or nil and false if not found or if called outside of an agent execution context. 
func GetRunLocalValue(ctx context.Context, key string) (any, bool, error) { var val any var found bool err := compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.Extra != nil { val, found = st.Extra[key] } return nil }) if err != nil { return nil, false, fmt.Errorf("GetRunLocalValue failed: must be called within a ChatModelAgent Run() or Resume() execution context: %w", err) } return val, found, nil } // DeleteRunLocalValue removes a value that was set during the current agent Run() invocation. // // This function can only be called from within a ChatModelAgentMiddleware during agent execution. // Returns an error if called outside of an agent execution context. func DeleteRunLocalValue(ctx context.Context, key string) error { err := compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.Extra != nil { delete(st.Extra, key) } return nil }) if err != nil { return fmt.Errorf("DeleteRunLocalValue failed: must be called within a ChatModelAgent Run() or Resume() execution context: %w", err) } return nil } // SendEvent sends a custom AgentEvent to the event stream during agent execution. // This allows ChatModelAgentMiddleware implementations to emit custom events that will be // received by the caller iterating over the agent's event stream. // // This function can only be called from within a ChatModelAgentMiddleware during agent execution. // Returns an error if called outside of an agent execution context. 
func SendEvent(ctx context.Context, event *AgentEvent) error { execCtx := getChatModelAgentExecCtx(ctx) if execCtx == nil || execCtx.generator == nil { return fmt.Errorf("SendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context") } execCtx.generator.Send(event) return nil } ================================================ FILE: adk/handler_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adk import ( "context" "sync" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) type testInstructionHandler struct { *BaseChatModelAgentMiddleware text string } func (h *testInstructionHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { if runCtx.Instruction == "" { runCtx.Instruction = h.text } else if h.text != "" { runCtx.Instruction = runCtx.Instruction + "\n" + h.text } return ctx, runCtx, nil } type testInstructionFuncHandler struct { *BaseChatModelAgentMiddleware fn func(ctx context.Context, instruction string) (context.Context, string, error) } func (h *testInstructionFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { newCtx, newInstruction, err := h.fn(ctx, runCtx.Instruction) if err != nil { return ctx, runCtx, err } runCtx.Instruction = newInstruction return newCtx, runCtx, nil } type testToolsHandler struct { *BaseChatModelAgentMiddleware tools []tool.BaseTool } func (h *testToolsHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { runCtx.Tools = append(runCtx.Tools, h.tools...) 
return ctx, runCtx, nil } type testToolsFuncHandler struct { *BaseChatModelAgentMiddleware fn func(ctx context.Context, tools []tool.BaseTool, returnDirectly map[string]bool) (context.Context, []tool.BaseTool, map[string]bool, error) } func (h *testToolsFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { newCtx, newTools, newReturnDirectly, err := h.fn(ctx, runCtx.Tools, runCtx.ReturnDirectly) if err != nil { return ctx, runCtx, err } runCtx.Tools = newTools runCtx.ReturnDirectly = newReturnDirectly return newCtx, runCtx, nil } type testBeforeAgentHandler struct { *BaseChatModelAgentMiddleware fn func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) } func (h *testBeforeAgentHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { return h.fn(ctx, runCtx) } type testBeforeModelRewriteStateHandler struct { *BaseChatModelAgentMiddleware fn func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) } func (h *testBeforeModelRewriteStateHandler) BeforeModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) { return h.fn(ctx, state, mc) } type testAfterModelRewriteStateHandler struct { *BaseChatModelAgentMiddleware fn func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) } func (h *testAfterModelRewriteStateHandler) AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) { return h.fn(ctx, state, mc) } type testToolWrapperHandler struct { *BaseChatModelAgentMiddleware wrapInvokableFn func(context.Context, InvokableToolCallEndpoint, *ToolContext) InvokableToolCallEndpoint wrapStreamableFn 
func(context.Context, StreamableToolCallEndpoint, *ToolContext) StreamableToolCallEndpoint } func (h *testToolWrapperHandler) WrapInvokableToolCall(ctx context.Context, endpoint InvokableToolCallEndpoint, tCtx *ToolContext) (InvokableToolCallEndpoint, error) { if h.wrapInvokableFn != nil { return h.wrapInvokableFn(ctx, endpoint, tCtx), nil } return endpoint, nil } func (h *testToolWrapperHandler) WrapStreamableToolCall(ctx context.Context, endpoint StreamableToolCallEndpoint, tCtx *ToolContext) (StreamableToolCallEndpoint, error) { if h.wrapStreamableFn != nil { return h.wrapStreamableFn(ctx, endpoint, tCtx), nil } return endpoint, nil } type testModelWrapperHandler struct { *BaseChatModelAgentMiddleware fn func(context.Context, model.BaseChatModel, *ModelContext) model.BaseChatModel } func (h *testModelWrapperHandler) WrapModel(ctx context.Context, m model.BaseChatModel, mc *ModelContext) (model.BaseChatModel, error) { return h.fn(ctx, m, mc), nil } func newTestInvokableToolCallWrapper(beforeFn, afterFn func()) func(context.Context, InvokableToolCallEndpoint, *ToolContext) InvokableToolCallEndpoint { return func(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) InvokableToolCallEndpoint { return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { if beforeFn != nil { beforeFn() } result, err := endpoint(ctx, argumentsInJSON, opts...) if afterFn != nil { afterFn() } return result, err } } } func newResultModifyingInvokableToolCallWrapper(modifyFn func(string) string) func(context.Context, InvokableToolCallEndpoint, *ToolContext) InvokableToolCallEndpoint { return func(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) InvokableToolCallEndpoint { return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { result, err := endpoint(ctx, argumentsInJSON, opts...) 
if err == nil && modifyFn != nil { result = modifyFn(result) } return result, err } } } func newTestStreamableToolCallWrapper(beforeFn, afterFn func()) func(context.Context, StreamableToolCallEndpoint, *ToolContext) StreamableToolCallEndpoint { return func(_ context.Context, endpoint StreamableToolCallEndpoint, _ *ToolContext) StreamableToolCallEndpoint { return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { if beforeFn != nil { beforeFn() } result, err := endpoint(ctx, argumentsInJSON, opts...) if afterFn != nil { afterFn() } return result, err } } } func TestHandlerExecutionOrder(t *testing.T) { t.Run("MultipleInstructionHandlersPipeline", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) var capturedInstruction string cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { if len(msgs) > 0 && msgs[0].Role == schema.System { capturedInstruction = msgs[0].Content } return schema.AssistantMessage("response", nil), nil }).Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Instruction: "Base instruction.", Model: cm, Handlers: []ChatModelAgentMiddleware{ &testInstructionHandler{text: "Handler 1 addition."}, &testInstructionHandler{text: "Handler 2 addition."}, &testInstructionFuncHandler{fn: func(ctx context.Context, instruction string) (context.Context, string, error) { return ctx, instruction + "\nHandler 3 dynamic.", nil }}, }, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) for { _, ok := iter.Next() if !ok { break } } assert.Contains(t, capturedInstruction, "Base instruction.") assert.Contains(t, capturedInstruction, "Handler 1 addition.") assert.Contains(t, capturedInstruction, 
"Handler 2 addition.") assert.Contains(t, capturedInstruction, "Handler 3 dynamic.") }) t.Run("MiddlewaresBeforeHandlers", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) var capturedInstruction string cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { if len(msgs) > 0 && msgs[0].Role == schema.System { capturedInstruction = msgs[0].Content } return schema.AssistantMessage("response", nil), nil }).Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Instruction: "Base.", Model: cm, Middlewares: []AgentMiddleware{ {AdditionalInstruction: "Middleware instruction."}, }, Handlers: []ChatModelAgentMiddleware{ &testInstructionHandler{text: "Handler instruction."}, }, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) for { _, ok := iter.Next() if !ok { break } } middlewareIdx := len(capturedInstruction) - len("Middleware instruction.") - len("\nHandler instruction.") handlerIdx := len(capturedInstruction) - len("Handler instruction.") assert.True(t, middlewareIdx < handlerIdx, "Middleware should be applied before Handler") }) } func TestToolsHandlerCombinations(t *testing.T) { t.Run("MultipleToolsHandlersAppend", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) tool1 := &fakeToolForTest{tarCount: 1} tool2 := &fakeToolForTest{tarCount: 2} cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() var capturedToolCount int cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { options := model.GetCommonOptions(&model.Options{}, opts...) 
capturedToolCount = len(options.Tools) return schema.AssistantMessage("response", nil), nil }).Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{tool1}, }, }, Handlers: []ChatModelAgentMiddleware{ &testToolsHandler{tools: []tool.BaseTool{tool2}}, }, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) for { _, ok := iter.Next() if !ok { break } } assert.Equal(t, 2, capturedToolCount) }) t.Run("ToolsFuncCanRemoveTools", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) tool1 := &namedTool{name: "tool1"} tool2 := &namedTool{name: "tool2"} tool3 := &namedTool{name: "tool3"} cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() var capturedToolNames []string cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { options := model.GetCommonOptions(&model.Options{}, opts...) 
for _, t := range options.Tools { capturedToolNames = append(capturedToolNames, t.Name) } return schema.AssistantMessage("response", nil), nil }).Times(1) agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: cm, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{tool1, tool2, tool3}, }, }, Handlers: []ChatModelAgentMiddleware{ &testToolsFuncHandler{fn: func(ctx context.Context, tools []tool.BaseTool, returnDirectly map[string]bool) (context.Context, []tool.BaseTool, map[string]bool, error) { filtered := make([]tool.BaseTool, 0) for _, t := range tools { info, _ := t.Info(ctx) if info.Name != "tool2" { filtered = append(filtered, t) } } return ctx, filtered, returnDirectly, nil }}, }, }) assert.NoError(t, err) iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) for { _, ok := iter.Next() if !ok { break } } assert.Contains(t, capturedToolNames, "tool1") assert.NotContains(t, capturedToolNames, "tool2") assert.Contains(t, capturedToolNames, "tool3") }) t.Run("ReturnDirectlyModification", func(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) tool1 := &namedTool{name: "tool1"} cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.AssistantMessage("Using tool", []schema.ToolCall{
				{ID: "call1", Function: schema.FunctionCall{Name: "tool1", Arguments: "{}"}},
			}), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			ToolsConfig: ToolsConfig{
				ToolsNodeConfig: compose.ToolsNodeConfig{
					Tools: []tool.BaseTool{tool1},
				},
			},
			Handlers: []ChatModelAgentMiddleware{
				&testToolsFuncHandler{fn: func(ctx context.Context, tools []tool.BaseTool, returnDirectly map[string]bool) (context.Context, []tool.BaseTool, map[string]bool, error) {
					for _, t := range tools {
						info, _ := t.Info(ctx)
						if info.Name == "tool1" {
							returnDirectly[info.Name] = true
						}
					}
					return ctx, tools, returnDirectly, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		eventCount := 0
		for {
			event, ok := iter.Next()
			if !ok {
				break
			}
			eventCount++
			if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.Message != nil && event.Output.MessageOutput.Message.Role == schema.Tool {
				assert.Equal(t, "tool1 result", event.Output.MessageOutput.Message.Content)
			}
		}
		assert.Equal(t, 2, eventCount)
	})

	t.Run("DynamicToolCanBeCalledByModel", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		dynamicToolCalled := false
		dynamicTool := &callableTool{
			name: "dynamic_tool",
			invokeFn: func() {
				dynamicToolCalled = true
			},
		}
		info, _ := dynamicTool.Info(ctx)

		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("Using dynamic tool", []schema.ToolCall{
				{ID: "call1", Function: schema.FunctionCall{Name: info.Name, Arguments: "{}"}},
			}), nil).Times(1)
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("done", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testToolsHandler{tools: []tool.BaseTool{dynamicTool}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.True(t, dynamicToolCalled, "Dynamic tool should have been called")
	})
}

// TestMessageRewriteHandlers covers the state-rewriting hooks: messages appended
// by several BeforeModelRewriteState handlers all reach the model, and an
// AfterModelRewriteState handler sees the model's assistant reply as the last
// message in the state.
func TestMessageRewriteHandlers(t *testing.T) {
	t.Run("BeforeModelRewriteStatePipeline", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		var capturedMsgCount int
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) {
				capturedMsgCount = len(msgs)
				return schema.AssistantMessage("response", nil), nil
			}).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Instruction: "instruction",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					state.Messages = append(state.Messages, schema.UserMessage("injected1"))
					return ctx, state, nil
				}},
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					state.Messages = append(state.Messages, schema.UserMessage("injected2"))
					return ctx, state, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("original")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		// 4 = "original" + "injected1" + "injected2" + one message for the
		// instruction (assumes the instruction is prepended as a single message).
		assert.Equal(t, 4, capturedMsgCount)
	})

	t.Run("AfterModelRewriteState", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		afterCalled := false
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testAfterModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					afterCalled = true
					// Only index when non-empty: a panic here would abort the agent
					// run instead of producing a readable assertion failure.
					if assert.NotEmpty(t, state.Messages) {
						lastMsg := state.Messages[len(state.Messages)-1]
						assert.Equal(t, schema.Assistant, lastMsg.Role)
					}
					return ctx, state, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.True(t, afterCalled)
	})
}

// TestToolCallWrapperHandlers exercises tool-call wrapper middleware: the nesting
// order of invokable and streamable wrappers, and a wrapper rewriting the tool
// result that is reported back as an event.
func TestToolCallWrapperHandlers(t *testing.T) {
	t.Run("MultipleToolWrappersPipeline", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		testTool := &namedTool{name: "test_tool"}
		info, _ := testTool.Info(ctx)

		var callOrder []string
		var mu sync.Mutex

		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("Using tool", []schema.ToolCall{
				{ID: "call1", Function: schema.FunctionCall{Name: info.Name, Arguments: "{}"}},
			}), nil).Times(1)
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("done", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			ToolsConfig: ToolsConfig{
				ToolsNodeConfig: compose.ToolsNodeConfig{
					Tools: []tool.BaseTool{testTool},
				},
			},
			Handlers: []ChatModelAgentMiddleware{
				&testToolWrapperHandler{wrapInvokableFn: newTestInvokableToolCallWrapper(
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper1-before")
						mu.Unlock()
					},
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper1-after")
						mu.Unlock()
					},
				)},
				&testToolWrapperHandler{wrapInvokableFn: newTestInvokableToolCallWrapper(
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper2-before")
						mu.Unlock()
					},
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper2-after")
						mu.Unlock()
					},
				)},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		// Tool-call wrappers nest with the last handler in the list outermost.
		assert.Equal(t, []string{"wrapper2-before", "wrapper1-before", "wrapper1-after", "wrapper2-after"}, callOrder)
	})

	t.Run("StreamingToolWrappersPipeline", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		testTool := &streamingNamedTool{name: "streaming_tool"}
		info, _ := testTool.Info(ctx)

		var callOrder []string
		var mu sync.Mutex

		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.StreamReaderFromArray([]*schema.Message{
				schema.AssistantMessage("Using tool", []schema.ToolCall{
					{ID: "call1", Function: schema.FunctionCall{Name: info.Name, Arguments: "{}"}},
				}),
			}), nil).Times(1)
		cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.StreamReaderFromArray([]*schema.Message{
				schema.AssistantMessage("done", nil),
			}), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			ToolsConfig: ToolsConfig{
				ToolsNodeConfig: compose.ToolsNodeConfig{
					Tools: []tool.BaseTool{testTool},
				},
			},
			Handlers: []ChatModelAgentMiddleware{
				&testToolWrapperHandler{wrapStreamableFn: newTestStreamableToolCallWrapper(
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper1-stream-before")
						mu.Unlock()
					},
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper1-stream-after")
						mu.Unlock()
					},
				)},
				&testToolWrapperHandler{wrapStreamableFn: newTestStreamableToolCallWrapper(
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper2-stream-before")
						mu.Unlock()
					},
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper2-stream-after")
						mu.Unlock()
					},
				)},
			},
		})
		assert.NoError(t, err)

		r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true, CheckPointStore: newBridgeStore()})
		iter := r.Run(ctx, []Message{schema.UserMessage("test")})

		var hasStreamingToolResult bool
		for {
			event, ok := iter.Next()
			if !ok {
				break
			}
			if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming && event.Output.MessageOutput.Role == schema.Tool {
				hasStreamingToolResult = true
				// Drain the tool stream; any error (including end-of-stream) stops reading.
				for {
					_, err := event.Output.MessageOutput.MessageStream.Recv()
					if err != nil {
						break
					}
				}
			}
		}
		assert.True(t, hasStreamingToolResult, "Should have streaming tool result")
		assert.Equal(t, []string{"wrapper2-stream-before", "wrapper1-stream-before", "wrapper1-stream-after", "wrapper2-stream-after"}, callOrder, "Streaming wrappers should be called in correct order")
	})

	t.Run("ToolWrapperCanModifyResult", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		testTool := &namedTool{name: "test_tool"}
		info, _ := testTool.Info(ctx)

		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("Using tool", []schema.ToolCall{
				{ID: "call1", Function: schema.FunctionCall{Name: info.Name, Arguments: "{}"}},
			}), nil).Times(1)
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("done", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			ToolsConfig: ToolsConfig{
				ToolsNodeConfig: compose.ToolsNodeConfig{
					Tools: []tool.BaseTool{testTool},
				},
			},
			Handlers: []ChatModelAgentMiddleware{
				&testToolWrapperHandler{wrapInvokableFn: newResultModifyingInvokableToolCallWrapper(func(result string) string {
					return "modified: " + result
				})},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		// Track whether a tool message was seen at all; otherwise the content
		// assertion below could never run and the test would pass vacuously.
		sawToolResult := false
		for {
			event, ok := iter.Next()
			if !ok {
				break
			}
			if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.Message != nil && event.Output.MessageOutput.Message.Role == schema.Tool {
				sawToolResult = true
				assert.Equal(t, "modified: test_tool result", event.Output.MessageOutput.Message.Content)
			}
		}
		assert.True(t, sawToolResult, "Should have a tool result event")
	})
}

// TestToolContextFunctions checks that the ModelContext handed to a model
// wrapper exposes the tools configured on the agent.
func TestToolContextFunctions(t *testing.T) {
	t.Run("ModelContextToolsField", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		testTool := &namedTool{name: "base_tool"}
		info, _ := testTool.Info(ctx)

		var wrapperSeenTools []*schema.ToolInfo

		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("done", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			ToolsConfig: ToolsConfig{
				ToolsNodeConfig: compose.ToolsNodeConfig{
					Tools: []tool.BaseTool{testTool},
				},
			},
			Handlers: []ChatModelAgentMiddleware{
				&testModelWrapperHandler{
					BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{},
					fn: func(_ context.Context, m model.BaseChatModel, mc *ModelContext) model.BaseChatModel {
						return &toolChainingTestModel{
							inner: m,
							mc:    mc,
							wrapFn: func(ctx context.Context, opts []model.Option) []model.Option {
								wrapperSeenTools = mc.Tools
								return opts
							},
						}
					},
				},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		// Guard the index: assert.Len does not stop the test, and indexing an
		// empty slice would panic instead of reporting a failure.
		if assert.Len(t, wrapperSeenTools, 1, "Wrapper should see base tool") {
			assert.Equal(t, info.Name, wrapperSeenTools[0].Name, "Wrapper should see base_tool")
		}
	})
}

// toolChainingTestModel decorates a model.BaseChatModel for tests. Before
// delegating Generate/Stream to inner, it passes the call options through
// wrapFn (when set), letting a test observe each call or rewrite its options.
// mc holds the ModelContext the wrapper was built with.
type toolChainingTestModel struct {
	inner  model.BaseChatModel
	mc     *ModelContext
	wrapFn func(ctx context.Context, opts []model.Option) []model.Option
}

// Generate applies wrapFn to opts (if non-nil) and delegates to the inner model.
func (m *toolChainingTestModel) Generate(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	if m.wrapFn != nil {
		opts = m.wrapFn(ctx, opts)
	}
	return m.inner.Generate(ctx, msgs, opts...)
}

// Stream applies wrapFn to opts (if non-nil) and delegates to the inner model.
func (m *toolChainingTestModel) Stream(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	if m.wrapFn != nil {
		opts = m.wrapFn(ctx, opts)
	}
	return m.inner.Stream(ctx, msgs, opts...)
}

// BindTools ignores the tools and always succeeds.
func (m *toolChainingTestModel) BindTools(tools []*schema.ToolInfo) error {
	return nil
}

// TestContextPropagation checks that the context returned by one handler is the
// context the next handler receives, for BeforeModelRewriteState and BeforeAgent.
func TestContextPropagation(t *testing.T) {
	t.Run("ContextPassedThroughBeforeModelHandlers", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		type ctxKey string
		const key1 ctxKey = "key1"
		const key2 ctxKey = "key2"

		var handler2ReceivedValue1 any

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					return context.WithValue(ctx, key1, "value1"), state, nil
				}},
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					handler2ReceivedValue1 = ctx.Value(key1)
					return context.WithValue(ctx, key2, "value2"), state, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.Equal(t, "value1", handler2ReceivedValue1, "Handler 2 should receive context value set by Handler 1")
	})

	t.Run("BeforeAgentContextPropagation", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		type ctxKey string
		const key1 ctxKey = "key1"

		var handler2ReceivedValue any

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
					return context.WithValue(ctx, key1, "value1"), runCtx, nil
				}},
				&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
					handler2ReceivedValue = ctx.Value(key1)
					return ctx, runCtx, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.Equal(t, "value1", handler2ReceivedValue, "Handler 2 should receive context value set by Handler 1 during BeforeAgent")
	})
}

// TestCustomHandler checks that a handler overriding BeforeAgent,
// BeforeModelRewriteState and AfterModelRewriteState has each hook invoked
// exactly once for a single model turn.
func TestCustomHandler(t *testing.T) {
	t.Run("CustomHandlerWithState", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		customHandler := &countingHandler{}
		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers:    []ChatModelAgentMiddleware{customHandler},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.Equal(t, 1, customHandler.beforeAgentCount)
		assert.Equal(t, 1, customHandler.beforeModelCount)
		assert.Equal(t, 1, customHandler.afterModelCount)
	})
}

// TestHandlerErrorHandling checks that an error returned by a BeforeAgent hook is
// surfaced as an event error mentioning "BeforeAgent failed".
func TestHandlerErrorHandling(t *testing.T) {
	t.Run("BeforeAgentErrorStopsRun", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
					return ctx, runCtx, assert.AnError
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{
			Messages: []*schema.Message{schema.UserMessage("test")},
		})
		var gotErr error
		for {
			event, ok := iter.Next()
			if !ok {
				break
			}
			if event.Err != nil {
				gotErr = event.Err
			}
		}
		// Only inspect the message when an error exists; gotErr.Error() on nil panics.
		if assert.Error(t, gotErr) {
			assert.Contains(t, gotErr.Error(), "BeforeAgent failed")
		}
	})
}

// namedTool is a minimal invokable tool: Info reports the configured name and
// InvokableRun returns "<name> result".
type namedTool struct {
	name string
}

func (t *namedTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{Name: t.name, Desc: t.name + " description"}, nil
}

func (t *namedTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) {
	return t.name + " result", nil
}

// streamingNamedTool is like namedTool but also implements StreamableRun, which
// yields a single chunk "<name> stream result".
type streamingNamedTool struct {
	name string
}

func (t *streamingNamedTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{Name: t.name, Desc: t.name + " description"}, nil
}

func (t *streamingNamedTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) {
	return t.name + " result", nil
}

func (t *streamingNamedTool) StreamableRun(_ context.Context, _ string, _ ...tool.Option) (*schema.StreamReader[string], error) {
	return schema.StreamReaderFromArray([]string{t.name + " stream result"}), nil
}

// callableTool is an invokable tool that runs invokeFn (when set) on every
// invocation, letting tests observe that the tool was called.
type callableTool struct {
	name     string
	invokeFn func()
}

func (t *callableTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{Name: t.name, Desc: t.name + " description"}, nil
}

func (t *callableTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) {
	if t.invokeFn != nil {
		t.invokeFn()
	}
	return t.name + " result", nil
}

// countingHandler counts invocations of BeforeAgent, BeforeModelRewriteState
// and AfterModelRewriteState under mu; other hooks come from the embedded
// BaseChatModelAgentMiddleware.
type countingHandler struct {
	*BaseChatModelAgentMiddleware
	beforeAgentCount int
	beforeModelCount int
	afterModelCount  int
	mu               sync.Mutex
}

func (h *countingHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
	h.mu.Lock()
	h.beforeAgentCount++
	h.mu.Unlock()
	return ctx, runCtx, nil
}

func (h *countingHandler) BeforeModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
	h.mu.Lock()
	h.beforeModelCount++
	h.mu.Unlock()
	return ctx, state, nil
}

func (h *countingHandler) AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
	h.mu.Lock()
	h.afterModelCount++
	h.mu.Unlock()
	return ctx, state, nil
}

// newTestModelWrapperFn returns a model-wrapping function that decorates the
// model with a testWrappedModel calling beforeFn/afterFn around each call.
func newTestModelWrapperFn(beforeFn, afterFn func()) func(context.Context, model.BaseChatModel, *ModelContext) model.BaseChatModel {
	return func(_ context.Context, m model.BaseChatModel, _ *ModelContext) model.BaseChatModel {
		return &testWrappedModel{
			inner:    m,
			beforeFn: beforeFn,
			afterFn:  afterFn,
		}
	}
}

// testWrappedModel calls beforeFn before and afterFn after delegating to inner.
// For Stream, afterFn runs as soon as inner.Stream returns, not when the
// returned stream is drained.
type testWrappedModel struct {
	inner    model.BaseChatModel
	beforeFn func()
	afterFn  func()
}

func (m *testWrappedModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	if m.beforeFn != nil {
		m.beforeFn()
	}
	result, err := m.inner.Generate(ctx, input, opts...)
	if m.afterFn != nil {
		m.afterFn()
	}
	return result, err
}

func (m *testWrappedModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	if m.beforeFn != nil {
		m.beforeFn()
	}
	result, err := m.inner.Stream(ctx, input, opts...)
	if m.afterFn != nil {
		m.afterFn()
	}
	return result, err
}

// TestModelWrapperHandlers checks model-wrapper nesting order and that a wrapper
// surrounds every model call, including each turn of a tool-calling loop.
func TestModelWrapperHandlers(t *testing.T) {
	t.Run("MultipleModelWrappersPipeline", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		var callOrder []string
		var mu sync.Mutex

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testModelWrapperHandler{fn: newTestModelWrapperFn(
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper1-before")
						mu.Unlock()
					},
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper1-after")
						mu.Unlock()
					},
				)},
				&testModelWrapperHandler{fn: newTestModelWrapperFn(
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper2-before")
						mu.Unlock()
					},
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper2-after")
						mu.Unlock()
					},
				)},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		// Model wrappers nest with the first handler in the list outermost
		// (the opposite of the tool-call wrapper order).
		assert.Equal(t, []string{"wrapper1-before", "wrapper2-before", "wrapper2-after", "wrapper1-after"}, callOrder)
	})

	t.Run("ModelWrapperBeforeAfterCallOrder", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		var callOrder []string
		var mu sync.Mutex

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) {
				mu.Lock()
				callOrder = append(callOrder, "model-generate")
				mu.Unlock()
				return schema.AssistantMessage("original response", nil), nil
			}).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testModelWrapperHandler{fn: newTestModelWrapperFn(
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper-before")
						mu.Unlock()
					},
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper-after")
						mu.Unlock()
					},
				)},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.Equal(t, []string{"wrapper-before", "model-generate", "wrapper-after"}, callOrder)
	})

	t.Run("ModelWrapperWithTools", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		testTool := &namedTool{name: "test_tool"}
		info, _ := testTool.Info(ctx)

		var callOrder []string
		var mu sync.Mutex

		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) {
				mu.Lock()
				callOrder = append(callOrder, "model-call")
				mu.Unlock()
				return schema.AssistantMessage("Using tool", []schema.ToolCall{
					{ID: "call1", Function: schema.FunctionCall{Name: info.Name, Arguments: "{}"}},
				}), nil
			}).Times(1)
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) {
				mu.Lock()
				callOrder = append(callOrder, "model-call")
				mu.Unlock()
				return schema.AssistantMessage("done", nil), nil
			}).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			ToolsConfig: ToolsConfig{
				ToolsNodeConfig: compose.ToolsNodeConfig{
					Tools: []tool.BaseTool{testTool},
				},
			},
			Handlers: []ChatModelAgentMiddleware{
				&testModelWrapperHandler{fn: newTestModelWrapperFn(
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper-before")
						mu.Unlock()
					},
					func() {
						mu.Lock()
						callOrder = append(callOrder, "wrapper-after")
						mu.Unlock()
					},
				)},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.Equal(t, []string{
			"wrapper-before", "model-call", "wrapper-after",
			"wrapper-before", "model-call", "wrapper-after",
		}, callOrder)
	})
}

// simpleChatModelWithoutCallbacks is a hand-written model stub. Generate and
// Stream delegate to generateFn/streamFn when set and otherwise return a single
// "default response" assistant message. WithTools ignores the tools and
// returns the same instance.
type simpleChatModelWithoutCallbacks struct {
	generateFn func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error)
	streamFn   func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error)
}

func (m *simpleChatModelWithoutCallbacks) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	if m.generateFn != nil {
		return m.generateFn(ctx, input, opts...)
	}
	return schema.AssistantMessage("default response", nil), nil
}

func (m *simpleChatModelWithoutCallbacks) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	if m.streamFn != nil {
		return m.streamFn(ctx, input, opts...)
	}
	return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("default response", nil)}), nil
}

func (m *simpleChatModelWithoutCallbacks) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
	return m, nil
}

// newInputModifyingWrapperFn returns a model-wrapping function that decorates the
// model with an inputOutputModifyingModel using inputPrefix.
func newInputModifyingWrapperFn(inputPrefix string) func(context.Context, model.BaseChatModel, *ModelContext) model.BaseChatModel {
	return func(_ context.Context, m model.BaseChatModel, _ *ModelContext) model.BaseChatModel {
		return &inputOutputModifyingModel{
			inner:       m,
			inputPrefix: inputPrefix,
		}
	}
}

// inputOutputModifyingModel rewrites user messages (see prefixUserMessages)
// before delegating to inner; the inner model's output is returned unchanged.
type inputOutputModifyingModel struct {
	inner       model.BaseChatModel
	inputPrefix string
}

// prefixUserMessages returns a new slice in which every user-role message is
// replaced by schema.UserMessage(inputPrefix + content); only the content is
// carried over to the replacement. Other messages are passed through as-is and
// the input slice is not modified.
func (m *inputOutputModifyingModel) prefixUserMessages(input []*schema.Message) []*schema.Message {
	modifiedMessages := make([]*schema.Message, len(input))
	for i, msg := range input {
		if msg.Role == schema.User {
			modifiedMessages[i] = schema.UserMessage(m.inputPrefix + msg.Content)
		} else {
			modifiedMessages[i] = msg
		}
	}
	return modifiedMessages
}

func (m *inputOutputModifyingModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	return m.inner.Generate(ctx, m.prefixUserMessages(input), opts...)
}

func (m *inputOutputModifyingModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	return m.inner.Stream(ctx, m.prefixUserMessages(input), opts...)
}

// TestModelWrapper_InputModification checks that input rewritten by a model
// wrapper is what the underlying model receives, for Generate and Stream.
func TestModelWrapper_InputModification(t *testing.T) {
	t.Run("ModelWrapperModifiesInput_Generate", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		var modelReceivedInput []*schema.Message
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
				modelReceivedInput = input
				return schema.AssistantMessage("original response", nil), nil
			}).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testModelWrapperHandler{fn: newInputModifyingWrapperFn("[WRAPPER]")},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test input")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}

		assert.NotEmpty(t, modelReceivedInput)
		found := false
		for _, msg := range modelReceivedInput {
			if msg.Content == "[WRAPPER]test input" {
				found = true
				break
			}
		}
		assert.True(t, found, "Model should receive wrapper-modified input")
	})

	t.Run("ModelWrapperModifiesInput_Stream", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		var modelReceivedInput []*schema.Message
		cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).
			DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
				modelReceivedInput = input
				return schema.StreamReaderFromArray([]*schema.Message{
					schema.AssistantMessage("chunk1", nil),
					schema.AssistantMessage("chunk2", nil),
				}), nil
			}).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testModelWrapperHandler{fn: newInputModifyingWrapperFn("[WRAPPER]")},
			},
		})
		assert.NoError(t, err)

		r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true, CheckPointStore: newBridgeStore()})
		iter := r.Run(ctx, []Message{schema.UserMessage("test input")})
		for {
			event, ok := iter.Next()
			if !ok {
				break
			}
			if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming && event.Output.MessageOutput.Role == schema.Assistant {
				// Drain the assistant stream; any error (including end-of-stream) stops reading.
				for {
					_, err := event.Output.MessageOutput.MessageStream.Recv()
					if err != nil {
						break
					}
				}
			}
		}

		assert.NotEmpty(t, modelReceivedInput)
		found := false
		for _, msg := range modelReceivedInput {
			if msg.Content == "[WRAPPER]test input" {
				found = true
				break
			}
		}
		assert.True(t, found, "Model should receive wrapper-modified input")
	})
}

// TestRunLocalValueFunctions covers SetRunLocalValue, GetRunLocalValue and
// DeleteRunLocalValue: use from handlers within a run, the errors they return
// outside an agent run, and persistence of a value across model calls in one run.
func TestRunLocalValueFunctions(t *testing.T) {
	t.Run("SetAndGetRunLocalValue", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		var capturedValue any
		var capturedFound bool

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					err := SetRunLocalValue(ctx, "test_key", "test_value")
					assert.NoError(t, err)
					return ctx, state, nil
				}},
				&testAfterModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					val, found, err := GetRunLocalValue(ctx, "test_key")
					assert.NoError(t, err)
					capturedValue = val
					capturedFound = found
					return ctx, state, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.True(t, capturedFound, "Value should be found")
		assert.Equal(t, "test_value", capturedValue, "Value should match what was set")
	})

	t.Run("DeleteRunLocalValue", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		var valueAfterDelete any
		var foundAfterDelete bool

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					err := SetRunLocalValue(ctx, "delete_key", "delete_value")
					assert.NoError(t, err)
					err = DeleteRunLocalValue(ctx, "delete_key")
					assert.NoError(t, err)
					return ctx, state, nil
				}},
				&testAfterModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					val, found, err := GetRunLocalValue(ctx, "delete_key")
					assert.NoError(t, err)
					valueAfterDelete = val
					foundAfterDelete = found
					return ctx, state, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.False(t, foundAfterDelete, "Value should not be found after deletion")
		assert.Nil(t, valueAfterDelete, "Value should be nil after deletion")
	})

	t.Run("GetNonExistentKey", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		var capturedValue any
		var capturedFound bool

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					val, found, err := GetRunLocalValue(ctx, "non_existent_key")
					assert.NoError(t, err)
					capturedValue = val
					capturedFound = found
					return ctx, state, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.False(t, capturedFound, "Non-existent key should not be found")
		assert.Nil(t, capturedValue, "Non-existent key should return nil value")
	})

	t.Run("RunLocalValueOutsideContext", func(t *testing.T) {
		ctx := context.Background()

		// Each message check is guarded: err.Error() on a nil error would panic.
		err := SetRunLocalValue(ctx, "key", "value")
		if assert.Error(t, err, "SetRunLocalValue should fail outside agent context") {
			assert.Contains(t, err.Error(), "SetRunLocalValue failed")
		}

		_, _, err = GetRunLocalValue(ctx, "key")
		if assert.Error(t, err, "GetRunLocalValue should fail outside agent context") {
			assert.Contains(t, err.Error(), "GetRunLocalValue failed")
		}

		err = DeleteRunLocalValue(ctx, "key")
		if assert.Error(t, err, "DeleteRunLocalValue should fail outside agent context") {
			assert.Contains(t, err.Error(), "DeleteRunLocalValue failed")
		}
	})

	t.Run("RunLocalValuePersistsAcrossModelCalls", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		testTool := &namedTool{name: "test_tool"}
		info, _ := testTool.Info(ctx)

		var firstCallValue any
		var secondCallValue any
		callCount := 0

		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("Using tool", []schema.ToolCall{
				{ID: "call1", Function: schema.FunctionCall{Name: info.Name, Arguments: "{}"}},
			}), nil).Times(1)
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("done", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			ToolsConfig: ToolsConfig{
				ToolsNodeConfig: compose.ToolsNodeConfig{
					Tools: []tool.BaseTool{testTool},
				},
			},
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					callCount++
					if callCount == 1 {
						err := SetRunLocalValue(ctx, "persist_key", "persist_value")
						assert.NoError(t, err)
						val, _, err := GetRunLocalValue(ctx, "persist_key")
						assert.NoError(t, err)
						firstCallValue = val
					} else {
						val, _, err := GetRunLocalValue(ctx, "persist_key")
						assert.NoError(t, err)
						secondCallValue = val
					}
					return ctx, state, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.Equal(t, "persist_value", firstCallValue, "First call should set value")
		assert.Equal(t, "persist_value", secondCallValue, "Value should persist to second model call")
	})
}

// TestHandlerErrorPropagation checks that errors from BeforeModelRewriteState and
// AfterModelRewriteState handlers surface as event errors, and that a failing
// handler prevents later handlers in the chain from running.
func TestHandlerErrorPropagation(t *testing.T) {
	t.Run("BeforeModelRewriteStateErrorStopsRun", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		// No Generate expectation is registered: gomock fails the test if the
		// model is called after the handler error.
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					return ctx, state, assert.AnError
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		var gotErr error
		for {
			event, ok := iter.Next()
			if !ok {
				break
			}
			if event.Err != nil {
				gotErr = event.Err
			}
		}
		assert.Error(t, gotErr)
	})

	t.Run("AfterModelRewriteStateErrorStopsRun", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("response", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testAfterModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					return ctx, state, assert.AnError
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		var gotErr error
		for {
			event, ok := iter.Next()
			if !ok {
				break
			}
			if event.Err != nil {
				gotErr = event.Err
			}
		}
		assert.Error(t, gotErr)
	})

	t.Run("MultipleHandlersFirstErrorStops", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		secondHandlerCalled := false

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			Handlers: []ChatModelAgentMiddleware{
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					return ctx, state, assert.AnError
				}},
				&testBeforeModelRewriteStateHandler{fn: func(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) {
					secondHandlerCalled = true
					return ctx, state, nil
				}},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		var gotErr error
		for {
			event, ok := iter.Next()
			if !ok {
				break
			}
			if event.Err != nil {
				gotErr = event.Err
			}
		}
		// The first handler's error must surface; without this check the test
		// would also pass if the run silently stopped for another reason.
		assert.Error(t, gotErr)
		assert.False(t, secondHandlerCalled, "Second handler should not be called after first handler error")
	})
}

// TestToolContextInWrappers checks that the ToolContext given to an invokable
// tool-call wrapper carries the tool name and the model-assigned call ID.
func TestToolContextInWrappers(t *testing.T) {
	t.Run("ToolContextHasCorrectNameAndCallID", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		cm := mockModel.NewMockToolCallingChatModel(ctrl)

		testTool := &namedTool{name: "context_test_tool"}
		info, _ := testTool.Info(ctx)

		var capturedToolName string
		var capturedCallID string

		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("Using tool", []schema.ToolCall{
				{ID: "test_call_id_123", Function: schema.FunctionCall{Name: info.Name, Arguments: "{}"}},
			}), nil).Times(1)
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("done", nil), nil).Times(1)

		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
			Name:        "TestAgent",
			Description: "Test agent",
			Model:       cm,
			ToolsConfig: ToolsConfig{
				ToolsNodeConfig: compose.ToolsNodeConfig{
					Tools: []tool.BaseTool{testTool},
				},
			},
			Handlers: []ChatModelAgentMiddleware{
				&testToolWrapperHandler{
					BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{},
					wrapInvokableFn: func(_ context.Context, endpoint InvokableToolCallEndpoint, tCtx *ToolContext) InvokableToolCallEndpoint {
						capturedToolName = tCtx.Name
						capturedCallID = tCtx.CallID
						return endpoint
					},
				},
			},
		})
		assert.NoError(t, err)

		iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})
		for {
			_, ok := iter.Next()
			if !ok {
				break
			}
		}
		assert.Equal(t, "context_test_tool", capturedToolName, "ToolContext should have correct tool name")
		assert.Equal(t, "test_call_id_123", capturedCallID, "ToolContext should have correct call ID")
	})
}

================================================ FILE: adk/instruction.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under
the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "fmt" "strings" "github.com/cloudwego/eino/adk/internal" ) const ( TransferToAgentInstruction = `Available other agents: %s Decision rule: - If you're best suited for the question according to your description: ANSWER - If another agent is better according its description: CALL '%s' function with their agent name When transferring: OUTPUT ONLY THE FUNCTION CALL` TransferToAgentInstructionChinese = `可用的其他 agent:%s 决策规则: - 如果根据你的职责描述,你最适合回答这个问题:ANSWER - 如果根据其职责描述,另一个 agent 更适合:调用 %s 函数,并传入该 agent 的名称 当进行移交时:只输出函数调用,不要输出其他任何内容` agentDescriptionTpl = "\n- Agent name: %s\n Agent description: %s" agentDescriptionTplChinese = "\n- Agent 名字: %s\n Agent 描述: %s" ) func genTransferToAgentInstruction(ctx context.Context, agents []Agent) string { tpl := internal.SelectPrompt(internal.I18nPrompts{ English: agentDescriptionTpl, Chinese: agentDescriptionTplChinese, }) instruction := internal.SelectPrompt(internal.I18nPrompts{ English: TransferToAgentInstruction, Chinese: TransferToAgentInstructionChinese, }) var sb strings.Builder for _, agent := range agents { sb.WriteString(fmt.Sprintf(tpl, agent.Name(ctx), agent.Description(ctx))) } return fmt.Sprintf(instruction, sb.String(), TransferToAgentToolName) } ================================================ FILE: adk/interface.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you 
may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "bytes" "context" "encoding/gob" "fmt" "io" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/internal/core" "github.com/cloudwego/eino/schema" ) // ComponentOfAgent is the component type identifier for ADK agents in callbacks. // Use this to filter callback events to only agent-related events. const ComponentOfAgent components.Component = "Agent" type Message = *schema.Message type MessageStream = *schema.StreamReader[Message] type MessageVariant struct { IsStreaming bool Message Message MessageStream MessageStream // message role: Assistant or Tool Role schema.RoleType // only used when Role is Tool ToolName string } // EventFromMessage wraps a message or stream into an AgentEvent with role metadata. 
func EventFromMessage(msg Message, msgStream MessageStream, role schema.RoleType, toolName string) *AgentEvent { return &AgentEvent{ Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: msgStream != nil, Message: msg, MessageStream: msgStream, Role: role, ToolName: toolName, }, }, } } type messageVariantSerialization struct { IsStreaming bool Message Message MessageStream Message Role schema.RoleType ToolName string } func (mv *MessageVariant) GobEncode() ([]byte, error) { s := &messageVariantSerialization{ IsStreaming: mv.IsStreaming, Message: mv.Message, Role: mv.Role, ToolName: mv.ToolName, } if mv.IsStreaming { var messages []Message for { frame, err := mv.MessageStream.Recv() if err == io.EOF { break } if err != nil { return nil, fmt.Errorf("error receiving message stream: %w", err) } messages = append(messages, frame) } m, err := schema.ConcatMessages(messages) if err != nil { return nil, fmt.Errorf("failed to encode message: cannot concat message stream: %w", err) } s.MessageStream = m } buf := &bytes.Buffer{} err := gob.NewEncoder(buf).Encode(s) if err != nil { return nil, fmt.Errorf("failed to gob encode message variant: %w", err) } return buf.Bytes(), nil } func (mv *MessageVariant) GobDecode(b []byte) error { s := &messageVariantSerialization{} err := gob.NewDecoder(bytes.NewReader(b)).Decode(s) if err != nil { return fmt.Errorf("failed to decoding message variant: %w", err) } mv.IsStreaming = s.IsStreaming mv.Message = s.Message mv.Role = s.Role mv.ToolName = s.ToolName if s.MessageStream != nil { mv.MessageStream = schema.StreamReaderFromArray([]*schema.Message{s.MessageStream}) } return nil } func (mv *MessageVariant) GetMessage() (Message, error) { var message Message if mv.IsStreaming { var err error message, err = schema.ConcatMessageStream(mv.MessageStream) if err != nil { return nil, err } } else { message = mv.Message } return message, nil } type TransferToAgentAction struct { DestAgentName string } type AgentOutput struct { 
MessageOutput *MessageVariant CustomizedOutput any } // NewTransferToAgentAction creates an action to transfer to the specified agent. func NewTransferToAgentAction(destAgentName string) *AgentAction { return &AgentAction{TransferToAgent: &TransferToAgentAction{DestAgentName: destAgentName}} } // NewExitAction creates an action that signals the agent to exit. func NewExitAction() *AgentAction { return &AgentAction{Exit: true} } // AgentAction represents actions that an agent can emit during execution. // // Action Scoping in Agent Tools: // When an agent is wrapped as an agent tool (via NewAgentTool), actions emitted by the inner agent // are scoped to the tool boundary: // - Interrupted: Propagated via CompositeInterrupt to allow proper interrupt/resume across boundaries // - Exit, TransferToAgent, BreakLoop: Ignored outside the agent tool; these actions only affect // the inner agent's execution and do not propagate to the parent agent // // This scoping ensures that nested agents cannot unexpectedly terminate or transfer control // of their parent agent's execution flow. type AgentAction struct { Exit bool Interrupted *InterruptInfo TransferToAgent *TransferToAgentAction BreakLoop *BreakLoopAction CustomizedAction any internalInterrupted *core.InterruptSignal } // RunStep CheckpointSchema: persisted via serialization.RunCtx (gob). 
type RunStep struct {
	agentName string
}

func init() {
	// Registers []RunStep under a fixed name. Per the checkpoint notes in this package, gob writes
	// registered type names into checkpoints for values stored as `any`, so this name should
	// presumably never change (verify before renaming).
	schema.RegisterName[[]RunStep]("eino_run_step_list")
}

// String returns the name of the agent this step refers to.
func (r *RunStep) String() string {
	return r.agentName
}

// Equals reports whether r and r1 refer to the same agent name.
func (r *RunStep) Equals(r1 RunStep) bool {
	return r.agentName == r1.agentName
}

// GobEncode implements gob.GobEncoder. The agent name is unexported, so it is
// carried through runStepSerialization, whose exported field gob can see.
func (r *RunStep) GobEncode() ([]byte, error) {
	s := &runStepSerialization{AgentName: r.agentName}
	buf := &bytes.Buffer{}
	err := gob.NewEncoder(buf).Encode(s)
	if err != nil {
		return nil, fmt.Errorf("failed to gob encode RunStep: %w", err)
	}
	return buf.Bytes(), nil
}

// GobDecode implements gob.GobDecoder, restoring the agent name written by GobEncode.
func (r *RunStep) GobDecode(b []byte) error {
	s := &runStepSerialization{}
	err := gob.NewDecoder(bytes.NewReader(b)).Decode(s)
	if err != nil {
		return fmt.Errorf("failed to gob decode RunStep: %w", err)
	}
	r.agentName = s.AgentName
	return nil
}

// runStepSerialization is the gob wire form of RunStep. Its field name is part of
// the persisted format and must not be renamed.
type runStepSerialization struct {
	AgentName string
}

// AgentEvent CheckpointSchema: persisted via serialization.RunCtx (gob).
type AgentEvent struct {
	AgentName string

	// RunPath represents the execution path from root agent to the current event source.
	// This field is managed entirely by the eino framework and cannot be set by end-users
	// because RunStep's fields are unexported. The framework sets RunPath exactly once:
	// - flowAgent sets it when the event has no RunPath (len == 0)
	// - agentTool prepends parent RunPath when forwarding events from nested agents
	RunPath []RunStep

	Output *AgentOutput

	Action *AgentAction

	Err error
}

// AgentInput is the input passed to Agent.Run.
type AgentInput struct {
	Messages        []Message
	EnableStreaming bool
}

//go:generate mockgen -destination ../internal/mock/adk/Agent_mock.go --package adk -source interface.go

// Agent is the interface implemented by every ADK agent.
type Agent interface {
	// Name returns the agent's name.
	Name(ctx context.Context) string
	// Description returns a human-readable description of what the agent does.
	Description(ctx context.Context) string

	// Run runs the agent.
	// The returned AgentEvent within the AsyncIterator must be safe to modify.
	// If the returned AgentEvent within the AsyncIterator contains MessageStream,
	// the MessageStream MUST be exclusive and safe to be received directly.
	// NOTE: it's recommended to use SetAutomaticClose() on the MessageStream of AgentEvents emitted by AsyncIterator,
	// so that even if the events are not processed, the MessageStream can still be closed.
	Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent]
}

// OnSubAgents is a set of hooks named after multi-agent wiring events: sub-agents
// being set, the agent being set as another agent's sub-agent, and transfer to the
// parent being disallowed.
type OnSubAgents interface {
	OnSetSubAgents(ctx context.Context, subAgents []Agent) error
	OnSetAsSubAgent(ctx context.Context, parent Agent) error

	OnDisallowTransferToParent(ctx context.Context) error
}

// ResumableAgent is an Agent that can continue an interrupted execution from a ResumeInfo.
type ResumableAgent interface {
	Agent

	Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent]
}

================================================
FILE: adk/internal/config.go
================================================
/*
 * Copyright 2026 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Package internal provides adk internal utils.
package internal

import (
	"fmt"
	"sync/atomic"
)

// Language represents the language setting for the ADK built-in prompts.
type Language uint8

const (
	// LanguageEnglish represents English language.
	LanguageEnglish Language = iota
	// LanguageChinese represents Chinese language.
	LanguageChinese
)

// language holds the Language stored by SetLanguage; while nothing has been
// stored, the built-in prompts default to English.
var language atomic.Value

// SetLanguage sets the language for the ADK built-in prompts.
// The default language is English if not explicitly set.
func SetLanguage(lang Language) error {
	// Only supported languages are ever stored, which is what makes the
	// default branch of SelectPrompt unreachable.
	if lang != LanguageEnglish && lang != LanguageChinese {
		return fmt.Errorf("invalid language: %v", lang)
	}
	language.Store(lang)
	return nil
}

// getLanguage returns the current language setting for the ADK built-in prompts.
// Returns LanguageEnglish if no language has been set.
func getLanguage() Language {
	if l, ok := language.Load().(Language); ok {
		return l
	}
	// Nothing stored yet: fall back to the documented English default.
	return LanguageEnglish
}

// I18nPrompts holds prompt strings for different languages.
type I18nPrompts struct {
	English string
	Chinese string
}

// SelectPrompt returns the appropriate prompt string based on the current language setting.
// It panics if the current language is not supported; this cannot happen in practice
// because SetLanguage rejects every value other than LanguageEnglish and LanguageChinese.
func SelectPrompt(prompts I18nPrompts) string {
	lang := getLanguage()
	switch lang {
	case LanguageEnglish:
		return prompts.English
	case LanguageChinese:
		return prompts.Chinese
	default:
		// unreachable: SetLanguage only stores supported languages.
		panic(fmt.Sprintf("invalid language: %v", lang))
	}
}

================================================
FILE: adk/interrupt.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package adk

import (
	"bytes"
	"context"
	"encoding/gob"
	"errors"
	"fmt"

	"github.com/cloudwego/eino/internal/core"
	"github.com/cloudwego/eino/schema"
)

// ResumeInfo holds all the information necessary to resume an interrupted agent execution.
// It is created by the framework and passed to an agent's Resume method.
type ResumeInfo struct { // EnableStreaming indicates whether the original execution was in streaming mode. EnableStreaming bool // Deprecated: use InterruptContexts from the embedded InterruptInfo for user-facing details, // and GetInterruptState for internal state retrieval. *InterruptInfo WasInterrupted bool InterruptState any IsResumeTarget bool ResumeData any } // InterruptInfo contains all the information about an interruption event. // It is created by the framework when an agent returns an interrupt action. type InterruptInfo struct { Data any // InterruptContexts provides a structured, user-facing view of the interrupt chain. // Each context represents a step in the agent hierarchy that was interrupted. InterruptContexts []*InterruptCtx } // Interrupt creates a basic interrupt action. // This is used when an agent needs to pause its execution to request external input or intervention, // but does not need to save any internal state to be restored upon resumption. // The `info` parameter is user-facing data that describes the reason for the interrupt. func Interrupt(ctx context.Context, info any) *AgentEvent { var rp []RunStep rCtx := getRunCtx(ctx) if rCtx != nil { rp = rCtx.RunPath } is, err := core.Interrupt(ctx, info, nil, nil, core.WithLayerPayload(rp)) if err != nil { return &AgentEvent{Err: err} } contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) return &AgentEvent{ Action: &AgentAction{ Interrupted: &InterruptInfo{ InterruptContexts: contexts, }, internalInterrupted: is, }, } } // StatefulInterrupt creates an interrupt action that also saves the agent's internal state. // This is used when an agent has internal state that must be restored for it to continue correctly. // The `info` parameter is user-facing data describing the interrupt. // The `state` parameter is the agent's internal state object, which will be serialized and stored. 
func StatefulInterrupt(ctx context.Context, info any, state any) *AgentEvent { var rp []RunStep rCtx := getRunCtx(ctx) if rCtx != nil { rp = rCtx.RunPath } is, err := core.Interrupt(ctx, info, state, nil, core.WithLayerPayload(rp)) if err != nil { return &AgentEvent{Err: err} } contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) return &AgentEvent{ Action: &AgentAction{ Interrupted: &InterruptInfo{ InterruptContexts: contexts, }, internalInterrupted: is, }, } } // CompositeInterrupt creates an interrupt action for a workflow agent. // It combines the interrupts from one or more of its sub-agents into a single, cohesive interrupt. // This is used by workflow agents (like Sequential, Parallel, or Loop) to propagate interrupts from their children. // The `info` parameter is user-facing data describing the workflow's own reason for interrupting. // The `state` parameter is the workflow agent's own state (e.g., the index of the sub-agent that was interrupted). // The `subInterruptSignals` is a variadic list of the InterruptSignal objects from the interrupted sub-agents. func CompositeInterrupt(ctx context.Context, info any, state any, subInterruptSignals ...*InterruptSignal) *AgentEvent { var rp []RunStep rCtx := getRunCtx(ctx) if rCtx != nil { rp = rCtx.RunPath } is, err := core.Interrupt(ctx, info, state, subInterruptSignals, core.WithLayerPayload(rp)) if err != nil { return &AgentEvent{Err: err} } contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) return &AgentEvent{ Action: &AgentAction{ Interrupted: &InterruptInfo{ InterruptContexts: contexts, }, internalInterrupted: is, }, } } // Address represents the unique, hierarchical address of a component within an execution. // It is a slice of AddressSegments, where each segment represents one level of nesting. // This is a type alias for core.Address. See the core package for more details. 
type Address = core.Address type AddressSegment = core.AddressSegment type AddressSegmentType = core.AddressSegmentType const ( AddressSegmentAgent AddressSegmentType = "agent" AddressSegmentTool AddressSegmentType = "tool" ) var allowedAddressSegmentTypes = []AddressSegmentType{AddressSegmentAgent, AddressSegmentTool} // AppendAddressSegment adds an address segment for the current execution context. func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID string) context.Context { return core.AppendAddressSegment(ctx, segType, segID, "") } // InterruptCtx provides a structured, user-facing view of a single point of interruption. // It contains the ID and Address of the interrupted component, as well as user-defined info. // This is a type alias for core.InterruptCtx. See the core package for more details. type InterruptCtx = core.InterruptCtx type InterruptSignal = core.InterruptSignal // FromInterruptContexts converts user-facing interrupt contexts to an interrupt signal. func FromInterruptContexts(contexts []*InterruptCtx) *InterruptSignal { return core.FromInterruptContexts(contexts) } // WithCheckPointID sets the checkpoint ID used for interruption persistence. func WithCheckPointID(id string) AgentRunOption { return WrapImplSpecificOptFn(func(t *options) { t.checkPointID = &id }) } func init() { schema.RegisterName[*serialization]("_eino_adk_serialization") schema.RegisterName[*WorkflowInterruptInfo]("_eino_adk_workflow_interrupt_info") } // serialization CheckpointSchema: root checkpoint payload (gob). // Any type tagged with `CheckpointSchema:` is persisted and must remain backward compatible. 
type serialization struct { RunCtx *runContext // deprecated: still keep it here for backward compatibility Info *InterruptInfo EnableStreaming bool InterruptID2Address map[string]Address InterruptID2State map[string]core.InterruptState } func (r *Runner) loadCheckPoint(ctx context.Context, checkpointID string) ( context.Context, *runContext, *ResumeInfo, error) { data, existed, err := r.store.Get(ctx, checkpointID) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get checkpoint from store: %w", err) } if !existed { return nil, nil, nil, fmt.Errorf("checkpoint[%s] not exist", checkpointID) } data = preprocessADKCheckpoint(data) s := &serialization{} err = gob.NewDecoder(bytes.NewReader(data)).Decode(s) if err != nil { return nil, nil, nil, fmt.Errorf("failed to decode checkpoint: %w", err) } ctx = core.PopulateInterruptState(ctx, s.InterruptID2Address, s.InterruptID2State) return ctx, s.RunCtx, &ResumeInfo{ EnableStreaming: s.EnableStreaming, InterruptInfo: s.Info, }, nil } // preprocessADKCheckpoint fixes a gob incompatibility when resuming old ChatModelAgent/DeepAgents checkpoints. // // Background // - ADK checkpoints are gob-encoded. // - Some values inside checkpoints are stored as `any`, so gob includes a concrete type name // string in the wire format and uses that name to pick the local Go type to decode into. // // Problem (v0.8.0-v0.8.3 checkpoints) // - In v0.8.0-v0.8.3, *State was registered under the name "_eino_adk_react_state" AND // implemented GobEncode/GobDecode, so the wire format for that name is "GobEncoder payload" // (opaque bytes). // - In v0.7.*, the same name "_eino_adk_react_state" was used but encoded as a normal struct // (no GobEncode). Gob treats these two wire formats as incompatible. // - Gob only allows one local Go type per name. Today we register "_eino_adk_react_state" to // a v0.7-compatible struct decoder (stateV07). 
If we try to decode a v0.8.0-v0.8.3 // checkpoint under that same name, gob fails with a "want struct; got non-struct" mismatch. // // Solution // - We keep "_eino_adk_react_state" mapped to the v0.7 decoder. // - For v0.8.0-v0.8.3 checkpoints only, we rewrite the on-wire name to a same-length alias // "_eino_adk_state_v080_", which is registered to a GobDecoder-compatible type (stateV080). // - The alias is the same length as the original, so we can safely replace the length-prefixed // bytes without re-encoding the whole stream. func preprocessADKCheckpoint(data []byte) []byte { const ( lenPrefixedReactStateName = "\x15" + stateGobNameV07 lenPrefixedCompatName = "\x15" + stateGobNameV080 lenPrefixedStateSerializationName = "\x12stateSerialization" ) // the following line checks whether the checkpoint is persisted through v0.8.0-v0.8.3 if !bytes.Contains(data, []byte(lenPrefixedReactStateName)) || !bytes.Contains(data, []byte(lenPrefixedStateSerializationName)) { return data } return bytes.ReplaceAll(data, []byte(lenPrefixedReactStateName), []byte(lenPrefixedCompatName)) } func (r *Runner) saveCheckPoint( ctx context.Context, key string, info *InterruptInfo, is *core.InterruptSignal, ) error { runCtx := getRunCtx(ctx) id2Addr, id2State := core.SignalToPersistenceMaps(is) buf := &bytes.Buffer{} err := gob.NewEncoder(buf).Encode(&serialization{ RunCtx: runCtx, Info: info, InterruptID2Address: id2Addr, InterruptID2State: id2State, EnableStreaming: r.enableStreaming, }) if err != nil { return fmt.Errorf("failed to encode checkpoint: %w", err) } return r.store.Set(ctx, key, buf.Bytes()) } const bridgeCheckpointID = "adk_react_mock_key" func newBridgeStore() *bridgeStore { return &bridgeStore{} } func newResumeBridgeStore(data []byte) *bridgeStore { return &bridgeStore{ Data: data, Valid: true, } } type bridgeStore struct { Data []byte Valid bool } func (m *bridgeStore) Get(_ context.Context, _ string) ([]byte, bool, error) { if m.Valid { return m.Data, true, nil } 
return nil, false, nil } func (m *bridgeStore) Set(_ context.Context, _ string, checkPoint []byte) error { m.Data = checkPoint m.Valid = true return nil } func getNextResumeAgent(ctx context.Context, info *ResumeInfo) (string, error) { nextAgents, err := core.GetNextResumptionPoints(ctx) if err != nil { return "", fmt.Errorf("failed to get next agent leading to interruption: %w", err) } if len(nextAgents) == 0 { return "", errors.New("no child agents leading to interrupted agent were found") } if len(nextAgents) > 1 { return "", errors.New("agent has multiple child agents leading to interruption, " + "but concurrent transfer is not supported") } // get the single next agent to delegate to. var nextAgentID string for id := range nextAgents { nextAgentID = id break } return nextAgentID, nil } func getNextResumeAgents(ctx context.Context, info *ResumeInfo) (map[string]bool, error) { nextAgents, err := core.GetNextResumptionPoints(ctx) if err != nil { return nil, fmt.Errorf("failed to get next agents leading to interruption: %w", err) } if len(nextAgents) == 0 { return nil, errors.New("no child agents leading to interrupted agent were found") } return nextAgents, nil } func buildResumeInfo(ctx context.Context, nextAgentID string, info *ResumeInfo) ( context.Context, *ResumeInfo) { ctx = AppendAddressSegment(ctx, AddressSegmentAgent, nextAgentID) nextResumeInfo := &ResumeInfo{ EnableStreaming: info.EnableStreaming, InterruptInfo: info.InterruptInfo, } wasInterrupted, hasState, state := core.GetInterruptState[any](ctx) nextResumeInfo.WasInterrupted = wasInterrupted if hasState { nextResumeInfo.InterruptState = state } if wasInterrupted { isResumeTarget, hasData, data := core.GetResumeContext[any](ctx) nextResumeInfo.IsResumeTarget = isResumeTarget if hasData { nextResumeInfo.ResumeData = data } } ctx = updateRunPathOnly(ctx, nextAgentID) return ctx, nextResumeInfo } ================================================ FILE: adk/interrupt_test.go 
================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "bytes" "context" "errors" "fmt" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) type interruptTestToolsHandler struct { *BaseChatModelAgentMiddleware tools []tool.BaseTool } func TestPreprocessADKCheckpoint(t *testing.T) { t.Run("no-op when missing markers", func(t *testing.T) { in := []byte("random") out := preprocessADKCheckpoint(append([]byte(nil), in...)) assert.Equal(t, in, out) }) t.Run("rewrite legacy name for v0.8.0-v0.8.3", func(t *testing.T) { const ( lenPrefixedReactStateName = "\x15" + stateGobNameV07 lenPrefixedCompatName = "\x15" + stateGobNameV080 lenPrefixedStateSerializationName = "\x12stateSerialization" ) in := []byte(lenPrefixedReactStateName + "xxx" + lenPrefixedStateSerializationName + "yyy") out := preprocessADKCheckpoint(append([]byte(nil), in...)) assert.True(t, bytes.Contains(out, []byte(lenPrefixedCompatName))) assert.False(t, bytes.Contains(out, []byte(lenPrefixedReactStateName))) }) } func (h *interruptTestToolsHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { runCtx.Tools = append(runCtx.Tools, h.tools...) 
return ctx, runCtx, nil } func TestSaveAgentEventWrapper(t *testing.T) { sr, sw := schema.Pipe[Message](1) sw.Send(schema.UserMessage("test"), nil) sw.Close() sr = sr.Copy(2)[1] w := &agentEventWrapper{ AgentEvent: &AgentEvent{ Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, MessageStream: sr, }, }, RunPath: []RunStep{ { "a1", }, { "a2", }, }, }, mu: sync.Mutex{}, concatenatedMessage: nil, } _, err := getMessageFromWrappedEvent(w) assert.NoError(t, err) buf, err := w.GobEncode() assert.NoError(t, err) assert.NoError(t, err) w1 := &agentEventWrapper{} err = w1.GobDecode(buf) assert.NoError(t, err) } func TestInterruptFunctionsPopulateInterruptContextsImmediately(t *testing.T) { ctx := context.Background() ctx, _ = initRunCtx(ctx, "TestAgent", &AgentInput{Messages: []Message{}}) ctx = AppendAddressSegment(ctx, AddressSegmentAgent, "TestAgent") t.Run("Interrupt populates InterruptContexts", func(t *testing.T) { event := Interrupt(ctx, "test info") assert.NotNil(t, event.Action) assert.NotNil(t, event.Action.Interrupted) assert.NotNil(t, event.Action.Interrupted.InterruptContexts) assert.Equal(t, 1, len(event.Action.Interrupted.InterruptContexts)) assert.Equal(t, "test info", event.Action.Interrupted.InterruptContexts[0].Info) assert.True(t, event.Action.Interrupted.InterruptContexts[0].IsRootCause) assert.Equal(t, Address{ {Type: AddressSegmentAgent, ID: "TestAgent"}, }, event.Action.Interrupted.InterruptContexts[0].Address) }) t.Run("StatefulInterrupt populates InterruptContexts", func(t *testing.T) { event := StatefulInterrupt(ctx, "stateful info", "my state") assert.NotNil(t, event.Action) assert.NotNil(t, event.Action.Interrupted) assert.NotNil(t, event.Action.Interrupted.InterruptContexts) assert.Equal(t, 1, len(event.Action.Interrupted.InterruptContexts)) assert.Equal(t, "stateful info", event.Action.Interrupted.InterruptContexts[0].Info) assert.True(t, event.Action.Interrupted.InterruptContexts[0].IsRootCause) }) 
t.Run("CompositeInterrupt populates InterruptContexts with filtered parent chain", func(t *testing.T) { subCtx := AppendAddressSegment(ctx, AddressSegmentAgent, "SubAgent") subEvent := Interrupt(subCtx, "sub info") event := CompositeInterrupt(ctx, "composite info", "composite state", subEvent.Action.internalInterrupted) assert.NotNil(t, event.Action) assert.NotNil(t, event.Action.Interrupted) assert.NotNil(t, event.Action.Interrupted.InterruptContexts) assert.Equal(t, 1, len(event.Action.Interrupted.InterruptContexts)) rootCause := event.Action.Interrupted.InterruptContexts[0] assert.Equal(t, "sub info", rootCause.Info) assert.True(t, rootCause.IsRootCause) assert.Equal(t, Address{ {Type: AddressSegmentAgent, ID: "TestAgent"}, {Type: AddressSegmentAgent, ID: "SubAgent"}, }, rootCause.Address) assert.NotNil(t, rootCause.Parent, "Parent should not be nil for composite interrupt") assert.Equal(t, "composite info", rootCause.Parent.Info) assert.Equal(t, Address{ {Type: AddressSegmentAgent, ID: "TestAgent"}, }, rootCause.Parent.Address) }) t.Run("Address only contains agent/tool segments", func(t *testing.T) { event := Interrupt(ctx, "test info") addr := event.Action.Interrupted.InterruptContexts[0].Address for _, seg := range addr { assert.True(t, seg.Type == AddressSegmentAgent || seg.Type == AddressSegmentTool, "Address should only contain agent/tool segments, got: %s", seg.Type) } }) } func TestSimpleInterrupt(t *testing.T) { data := "hello world" agent := &myAgent{ runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, Message: nil, MessageStream: schema.StreamReaderFromArray([]Message{ schema.UserMessage("hello "), schema.UserMessage("world"), }), }, }, }) intEvent := Interrupt(ctx, data) intEvent.Action.Interrupted.Data = data generator.Send(intEvent) 
generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.True(t, info.WasInterrupted) assert.Nil(t, info.InterruptState) assert.True(t, info.EnableStreaming) assert.Equal(t, data, info.Data) assert.True(t, info.IsResumeTarget) iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Close() return iter }, } store := newMyStore() ctx := context.Background() runner := NewRunner(ctx, RunnerConfig{ Agent: agent, EnableStreaming: true, CheckPointStore: store, }) iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) _, ok := iter.Next() assert.True(t, ok) interruptEvent, ok := iter.Next() assert.True(t, ok) assert.Equal(t, data, interruptEvent.Action.Interrupted.Data) assert.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts[0].ID) assert.True(t, interruptEvent.Action.Interrupted.InterruptContexts[0].IsRootCause) assert.Equal(t, data, interruptEvent.Action.Interrupted.InterruptContexts[0].Info) assert.Equal(t, Address{{Type: AddressSegmentAgent, ID: "myAgent"}}, interruptEvent.Action.Interrupted.InterruptContexts[0].Address) _, ok = iter.Next() assert.False(t, ok) iter, err := runner.ResumeWithParams(ctx, "1", &ResumeParams{ Targets: map[string]any{ interruptEvent.Action.Interrupted.InterruptContexts[0].ID: nil, }, }) assert.NoError(t, err) _, ok = iter.Next() assert.False(t, ok) } func TestMultiAgentInterrupt(t *testing.T) { ctx := context.Background() sa1 := &myAgent{ name: "sa1", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "sa1", Action: &AgentAction{ TransferToAgent: &TransferToAgentAction{ DestAgentName: "sa2", }, }, }) generator.Close() return iter }, } sa2 := &myAgent{ name: "sa2", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { 
iter, generator := NewAsyncIteratorPair[*AgentEvent]() intEvent := StatefulInterrupt(ctx, "hello world", "temp state") intEvent.Action.Interrupted.Data = "hello world" generator.Send(intEvent) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.NotNil(t, info) assert.Equal(t, info.Data, "hello world") assert.True(t, info.WasInterrupted) assert.NotNil(t, info.InterruptState) assert.Equal(t, "temp state", info.InterruptState) assert.True(t, info.IsResumeTarget) assert.NotNil(t, info.ResumeData) assert.Equal(t, "resume data", info.ResumeData) iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "sa2", Output: &AgentOutput{ MessageOutput: &MessageVariant{Message: schema.UserMessage(info.ResumeData.(string))}, }, }) generator.Close() return iter }, } a, err := SetSubAgents(ctx, sa1, []Agent{sa2}) assert.NoError(t, err) runner := NewRunner(ctx, RunnerConfig{ Agent: a, EnableStreaming: false, CheckPointStore: newMyStore(), }) iter := runner.Query(ctx, "", WithCheckPointID("1")) event, ok := iter.Next() assert.True(t, ok) assert.NotNil(t, event.Action.TransferToAgent) event, ok = iter.Next() assert.True(t, ok) assert.NotNil(t, event.Action.Interrupted) assert.Equal(t, 1, len(event.Action.Interrupted.InterruptContexts)) assert.Equal(t, "hello world", event.Action.Interrupted.InterruptContexts[0].Info) assert.True(t, event.Action.Interrupted.InterruptContexts[0].IsRootCause) assert.Equal(t, Address{ {Type: AddressSegmentAgent, ID: "sa1"}, {Type: AddressSegmentAgent, ID: "sa2"}, }, event.Action.Interrupted.InterruptContexts[0].Address) assert.NotEmpty(t, event.Action.Interrupted.InterruptContexts[0].ID) interruptID := event.Action.Interrupted.InterruptContexts[0].ID _, ok = iter.Next() assert.False(t, ok) iter, err = runner.ResumeWithParams(ctx, "1", &ResumeParams{ Targets: map[string]any{ interruptID: "resume data", }, }) assert.NoError(t, 
err) event, ok = iter.Next() assert.True(t, ok) assert.Equal(t, event.Output.MessageOutput.Message.Content, "resume data") _, ok = iter.Next() assert.False(t, ok) } func TestWorkflowInterrupt(t *testing.T) { ctx := context.Background() sa1 := &myAgent{ name: "sa1", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() intEvent := Interrupt(ctx, "sa1 interrupt data") intEvent.Action.Interrupted.Data = "sa1 interrupt data" generator.Send(intEvent) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.Equal(t, info.InterruptInfo.Data, "sa1 interrupt data") assert.True(t, info.WasInterrupted) assert.Nil(t, info.InterruptState) assert.True(t, info.IsResumeTarget) assert.Equal(t, "resume sa1", info.ResumeData) iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Close() return iter }, } // interrupt once sa2 := &myAgent{ name: "sa2", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() intEvent := StatefulInterrupt(ctx, "sa2 interrupt data", "sa2 interrupt") intEvent.Action.Interrupted.Data = "sa2 interrupt data" generator.Send(intEvent) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.Equal(t, info.InterruptInfo.Data, "sa2 interrupt data") assert.True(t, info.WasInterrupted) assert.NotNil(t, info.InterruptState) assert.Equal(t, "sa2 interrupt", info.InterruptState) assert.True(t, info.IsResumeTarget) assert.NotNil(t, info.ResumeData) assert.Equal(t, "resume sa2", info.ResumeData) iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Close() return iter }, } // interrupt once sa3 := &myAgent{ name: "sa3", runFn: func(ctx context.Context, 
input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "sa3", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa3 completed"), }, }, }) generator.Close() return iter }, } // won't interrupt sa4 := &myAgent{ name: "sa4", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "sa4", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa4 completed"), }, }, }) generator.Close() return iter }, } // won't interrupt firstInterruptEvent := &AgentEvent{ AgentName: "sa1", RunPath: []RunStep{{"sequential"}, {"sa1"}}, Action: &AgentAction{ Interrupted: &InterruptInfo{ Data: &WorkflowInterruptInfo{ OrigInput: &AgentInput{ Messages: []Message{schema.UserMessage("hello world")}, }, SequentialInterruptIndex: 0, SequentialInterruptInfo: &InterruptInfo{ Data: "sa1 interrupt data", }, LoopIterations: 0, }, InterruptContexts: []*InterruptCtx{ { ID: "agent:sequential;agent:sa1", Info: "sa1 interrupt data", Address: Address{ { ID: "sequential", Type: AddressSegmentAgent, }, { ID: "sa1", Type: AddressSegmentAgent, }, }, IsRootCause: true, Parent: &InterruptCtx{ ID: "agent:sequential", Info: "Sequential workflow interrupted", Address: Address{ { ID: "sequential", Type: AddressSegmentAgent, }, }, }, }, }, }, }, } _ = firstInterruptEvent secondInterruptEvent := &AgentEvent{ AgentName: "sa2", RunPath: []RunStep{{"sequential"}, {"sa1"}, {"sa2"}}, Action: &AgentAction{ Interrupted: &InterruptInfo{ Data: &WorkflowInterruptInfo{ OrigInput: &AgentInput{ Messages: []Message{schema.UserMessage("hello world")}, }, SequentialInterruptIndex: 1, SequentialInterruptInfo: &InterruptInfo{ Data: "sa2 interrupt data", }, }, InterruptContexts: []*InterruptCtx{ { ID: 
"agent:sequential;agent:sa1;agent:sa2", Info: "sa2 interrupt data", Address: Address{ { ID: "sequential", Type: AddressSegmentAgent, }, { ID: "sa2", Type: AddressSegmentAgent, }, }, IsRootCause: true, Parent: &InterruptCtx{ ID: "agent:sequential", Info: "Sequential workflow interrupted", Address: Address{ { ID: "sequential", Type: AddressSegmentAgent, }, }, }, }, }, }, }, } _ = secondInterruptEvent messageEvents := []*AgentEvent{ { AgentName: "sa3", RunPath: []RunStep{{"sequential"}, {"sa1"}, {"sa2"}, {"sa3"}}, Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa3 completed"), }, }, }, { AgentName: "sa4", RunPath: []RunStep{{"sequential"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}}, Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa4 completed"), }, }, }, } _ = messageEvents t.Run("test sequential workflow agent", func(t *testing.T) { // sequential a, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ Name: "sequential", Description: "sequential agent", SubAgents: []Agent{sa1, sa2, sa3, sa4}, }) assert.NoError(t, err) runner := NewRunner(ctx, RunnerConfig{ Agent: a, CheckPointStore: newMyStore(), }) var events []*AgentEvent iter := runner.Query(ctx, "hello world", WithCheckPointID("sequential-1")) for { event, ok := iter.Next() if !ok { break } events = append(events, event) } assert.Equal(t, 1, len(events)) assert.Equal(t, firstInterruptEvent.AgentName, events[0].AgentName) assert.Equal(t, firstInterruptEvent.RunPath, events[0].RunPath) assert.True(t, events[0].Action.Interrupted.InterruptContexts[0].EqualsWithoutID(firstInterruptEvent.Action.Interrupted.InterruptContexts[0])) interruptID1 := events[0].Action.Interrupted.InterruptContexts[0].ID events = []*AgentEvent{} // Resume after sa1 interrupt iter, err = runner.ResumeWithParams(ctx, "sequential-1", &ResumeParams{ Targets: map[string]any{ interruptID1: "resume sa1", }, }) assert.NoError(t, err) for { event, ok := iter.Next() if !ok { break } 
events = append(events, event) } assert.Equal(t, 1, len(events)) assert.Equal(t, secondInterruptEvent.AgentName, events[0].AgentName) assert.Equal(t, secondInterruptEvent.RunPath, events[0].RunPath) assert.True(t, events[0].Action.Interrupted.InterruptContexts[0]. EqualsWithoutID(secondInterruptEvent.Action.Interrupted.InterruptContexts[0])) interruptID2 := events[0].Action.Interrupted.InterruptContexts[0].ID events = []*AgentEvent{} // Resume after sa2 interrupt iter, err = runner.ResumeWithParams(ctx, "sequential-1", &ResumeParams{ Targets: map[string]any{ interruptID2: "resume sa2", }, }) assert.NoError(t, err) for { event, ok := iter.Next() if !ok { break } events = append(events, event) } assert.Equal(t, 2, len(events)) assert.Equal(t, messageEvents, events) }) t.Run("test loop workflow agent", func(t *testing.T) { // loop a, err := NewLoopAgent(ctx, &LoopAgentConfig{ Name: "loop", SubAgents: []Agent{sa1, sa2, sa3, sa4}, MaxIterations: 2, }) assert.NoError(t, err) runner := NewRunner(ctx, RunnerConfig{ Agent: a, CheckPointStore: newMyStore(), }) var events []*AgentEvent iter := runner.Query(ctx, "hello world", WithCheckPointID("loop-1")) for { event, ok := iter.Next() if !ok { break } events = append(events, event) } loopFirstInterruptEvent := &AgentEvent{ AgentName: "sa1", RunPath: []RunStep{{"loop"}, {"sa1"}}, Action: &AgentAction{ Interrupted: &InterruptInfo{ Data: &WorkflowInterruptInfo{ OrigInput: &AgentInput{ Messages: []Message{schema.UserMessage("hello world")}, }, SequentialInterruptIndex: 0, SequentialInterruptInfo: &InterruptInfo{ Data: "sa1 interrupt data", }, LoopIterations: 0, }, InterruptContexts: []*InterruptCtx{ { ID: "agent:loop;agent:sa1", Info: "sa1 interrupt data", Address: Address{ { ID: "loop", Type: AddressSegmentAgent, }, { ID: "sa1", Type: AddressSegmentAgent, }, }, IsRootCause: true, Parent: &InterruptCtx{ ID: "agent:loop", Info: "Loop workflow interrupted", Address: Address{ { ID: "loop", Type: AddressSegmentAgent, }, }, }, }, }, }, 
}, } assert.Equal(t, 1, len(events)) assert.Equal(t, loopFirstInterruptEvent.AgentName, events[0].AgentName) assert.Equal(t, loopFirstInterruptEvent.RunPath, events[0].RunPath) assert.True(t, events[0].Action.Interrupted.InterruptContexts[0].EqualsWithoutID(loopFirstInterruptEvent.Action.Interrupted.InterruptContexts[0])) loopInterruptID1 := events[0].Action.Interrupted.InterruptContexts[0].ID events = []*AgentEvent{} // Resume after sa1 interrupt iter, err = runner.ResumeWithParams(ctx, "loop-1", &ResumeParams{ Targets: map[string]any{ loopInterruptID1: "resume sa1", }, }) assert.NoError(t, err) for { event, ok := iter.Next() if !ok { break } events = append(events, event) } loopSecondInterruptEvent := &AgentEvent{ AgentName: "sa2", RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}}, Action: &AgentAction{ Interrupted: &InterruptInfo{ Data: &WorkflowInterruptInfo{ OrigInput: &AgentInput{ Messages: []Message{schema.UserMessage("hello world")}, }, SequentialInterruptIndex: 1, SequentialInterruptInfo: &InterruptInfo{ Data: "sa2 interrupt data", }, LoopIterations: 0, }, InterruptContexts: []*InterruptCtx{ { ID: "agent:loop;agent:sa1;agent:sa2", Info: "sa2 interrupt data", Address: Address{ { ID: "loop", Type: AddressSegmentAgent, }, { ID: "sa2", Type: AddressSegmentAgent, }, }, IsRootCause: true, Parent: &InterruptCtx{ ID: "agent:loop", Info: "Loop workflow interrupted", Address: Address{ { ID: "loop", Type: AddressSegmentAgent, }, }, }, }, }, }, }, } assert.Equal(t, 1, len(events)) assert.Equal(t, loopSecondInterruptEvent.AgentName, events[0].AgentName) assert.Equal(t, loopSecondInterruptEvent.RunPath, events[0].RunPath) assert.True(t, events[0].Action.Interrupted.InterruptContexts[0].EqualsWithoutID(loopSecondInterruptEvent.Action.Interrupted.InterruptContexts[0])) loopInterruptID2 := events[0].Action.Interrupted.InterruptContexts[0].ID events = []*AgentEvent{} // Resume after sa2 interrupt iter, err = runner.ResumeWithParams(ctx, "loop-1", &ResumeParams{ Targets: 
map[string]any{ loopInterruptID2: "resume sa2", }, }) assert.NoError(t, err) for { event, ok := iter.Next() if !ok { break } events = append(events, event) } loopThirdInterruptEvent := &AgentEvent{ AgentName: "sa1", RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}, {"sa1"}}, Action: &AgentAction{ Interrupted: &InterruptInfo{ Data: &WorkflowInterruptInfo{ OrigInput: &AgentInput{ Messages: []Message{schema.UserMessage("hello world")}, }, SequentialInterruptIndex: 0, SequentialInterruptInfo: &InterruptInfo{ Data: "sa1 interrupt data", }, LoopIterations: 1, }, InterruptContexts: []*InterruptCtx{ { ID: "agent:loop;agent:sa1;agent:sa2;agent:sa3;agent:sa4;agent:sa1", Info: "sa1 interrupt data", Address: Address{ { ID: "loop", Type: AddressSegmentAgent, }, { ID: "sa1", Type: AddressSegmentAgent, }, }, IsRootCause: true, Parent: &InterruptCtx{ ID: "agent:loop", Info: "Loop workflow interrupted", Address: Address{ { ID: "loop", Type: AddressSegmentAgent, }, }, }, }, }, }, }, } loopFourthInterruptEvent := &AgentEvent{ AgentName: "sa2", RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}, {"sa1"}, {"sa2"}}, Action: &AgentAction{ Interrupted: &InterruptInfo{ Data: &WorkflowInterruptInfo{ OrigInput: &AgentInput{ Messages: []Message{schema.UserMessage("hello world")}, }, SequentialInterruptIndex: 1, SequentialInterruptInfo: &InterruptInfo{ Data: "sa2 interrupt data", }, LoopIterations: 1, }, InterruptContexts: []*InterruptCtx{ { ID: "agent:loop;agent:sa1;agent:sa2;agent:sa3;agent:sa4;agent:sa1;agent:sa2", Info: "sa2 interrupt data", Address: Address{ { ID: "loop", Type: AddressSegmentAgent, }, { ID: "sa2", Type: AddressSegmentAgent, }, }, IsRootCause: true, Parent: &InterruptCtx{ ID: "agent:loop", Info: "Loop workflow interrupted", Address: Address{ { ID: "loop", Type: AddressSegmentAgent, }, }, }, }, }, }, }, } loopMessageEvents := []*AgentEvent{ { AgentName: "sa3", RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}}, Output: &AgentOutput{ 
MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa3 completed"), }, }, }, { AgentName: "sa4", RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}}, Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa4 completed"), }, }, }, loopThirdInterruptEvent, } assert.Equal(t, 3, len(events)) // Check the first two message events assert.Equal(t, loopMessageEvents[0].AgentName, events[0].AgentName) assert.Equal(t, loopMessageEvents[0].RunPath, events[0].RunPath) assert.Equal(t, loopMessageEvents[0].Output.MessageOutput.Message.Content, events[0].Output.MessageOutput.Message.Content) assert.Equal(t, loopMessageEvents[1].AgentName, events[1].AgentName) assert.Equal(t, loopMessageEvents[1].RunPath, events[1].RunPath) assert.Equal(t, loopMessageEvents[1].Output.MessageOutput.Message.Content, events[1].Output.MessageOutput.Message.Content) // Check the third interrupt event using EqualsWithoutID assert.Equal(t, loopMessageEvents[2].AgentName, events[2].AgentName) assert.Equal(t, loopMessageEvents[2].RunPath, events[2].RunPath) assert.True(t, events[2].Action.Interrupted.InterruptContexts[0].EqualsWithoutID(loopMessageEvents[2].Action.Interrupted.InterruptContexts[0])) loopInterruptID3 := events[2].Action.Interrupted.InterruptContexts[0].ID events = []*AgentEvent{} // Resume after third interrupt iter, err = runner.ResumeWithParams(ctx, "loop-1", &ResumeParams{ Targets: map[string]any{ loopInterruptID3: "resume sa1", }, }) assert.NoError(t, err) for { event, ok := iter.Next() if !ok { break } events = append(events, event) } assert.Equal(t, 1, len(events)) assert.Equal(t, loopFourthInterruptEvent.AgentName, events[0].AgentName) assert.Equal(t, loopFourthInterruptEvent.RunPath, events[0].RunPath) assert.True(t, events[0].Action.Interrupted.InterruptContexts[0].EqualsWithoutID(loopFourthInterruptEvent.Action.Interrupted.InterruptContexts[0])) loopInterruptID4 := events[0].Action.Interrupted.InterruptContexts[0].ID events = 
[]*AgentEvent{} // Resume after fourth interrupt iter, err = runner.ResumeWithParams(ctx, "loop-1", &ResumeParams{ Targets: map[string]any{ loopInterruptID4: "resume sa2", }, }) assert.NoError(t, err) for { event, ok := iter.Next() if !ok { break } events = append(events, event) } loopFinalMessageEvents := []*AgentEvent{ { AgentName: "sa3", RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}, {"sa1"}, {"sa2"}, {"sa3"}}, Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa3 completed"), }, }, }, { AgentName: "sa4", RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}}, Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa4 completed"), }, }, }, } assert.Equal(t, 2, len(events)) assert.Equal(t, loopFinalMessageEvents, events) }) t.Run("test parallel workflow agent", func(t *testing.T) { // parallel a, err := NewParallelAgent(ctx, &ParallelAgentConfig{ Name: "parallel agent", SubAgents: []Agent{sa1, sa2, sa3, sa4}, }) assert.NoError(t, err) runner := NewRunner(ctx, RunnerConfig{ Agent: a, CheckPointStore: newMyStore(), }) iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) var ( events []*AgentEvent interruptEvent *AgentEvent ) for { event, ok := iter.Next() if !ok { break } if event.Action != nil && event.Action.Interrupted != nil { interruptEvent = event continue } events = append(events, event) } assert.Equal(t, 2, len(events)) // Debug: Print actual events to see what we're getting for i, event := range events { t.Logf("Event %d: AgentName=%s, RunPath=%v, Output=%v", i, event.AgentName, event.RunPath, event.Output) } // Define parallel message events separately parallelMessageEvents := []*AgentEvent{ { AgentName: "sa4", RunPath: []RunStep{{"parallel agent"}, {"sa4"}}, Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa4 completed"), }, }, }, { AgentName: "sa3", RunPath: []RunStep{{"parallel 
agent"}, {"sa3"}}, Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa3 completed"), }, }, }, } assert.Contains(t, events, parallelMessageEvents[0]) assert.Contains(t, events, parallelMessageEvents[1]) assert.NotNil(t, interruptEvent) assert.Equal(t, "parallel agent", interruptEvent.AgentName) assert.Equal(t, []RunStep{{"parallel agent"}}, interruptEvent.RunPath) assert.NotNil(t, interruptEvent.Action.Interrupted) wii, ok := interruptEvent.Action.Interrupted.Data.(*WorkflowInterruptInfo) assert.True(t, ok) assert.Equal(t, 2, len(wii.ParallelInterruptInfo)) var sa1Found, sa2Found bool for _, info := range wii.ParallelInterruptInfo { switch info.Data { case "sa1 interrupt data": sa1Found = true case "sa2 interrupt data": sa2Found = true } } assert.True(t, sa1Found) assert.True(t, sa2Found) var sa1InfoFound, sa2InfoFound bool for _, ctx := range interruptEvent.Action.Interrupted.InterruptContexts { if ctx.Info == "sa1 interrupt data" { sa1InfoFound = true } else if ctx.Info == "sa2 interrupt data" { sa2InfoFound = true } } assert.Equal(t, 2, len(interruptEvent.Action.Interrupted.InterruptContexts)) assert.True(t, sa1InfoFound) assert.True(t, sa2InfoFound) var parallelInterruptID1, parallelInterruptID2 string for _, ctx := range interruptEvent.Action.Interrupted.InterruptContexts { if ctx.Info == "sa1 interrupt data" { parallelInterruptID1 = ctx.ID } else if ctx.Info == "sa2 interrupt data" { parallelInterruptID2 = ctx.ID } } assert.NotEmpty(t, parallelInterruptID1) assert.NotEmpty(t, parallelInterruptID2) iter, err = runner.ResumeWithParams(ctx, "1", &ResumeParams{ Targets: map[string]any{ parallelInterruptID1: "resume sa1", parallelInterruptID2: "resume sa2", }, }) assert.NoError(t, err) _, ok = iter.Next() assert.False(t, ok) }) } func TestChatModelInterrupt(t *testing.T) { ctx := context.Background() a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "name", Description: "description", Instruction: "instruction", Model: 
&myModel{ validator: func(i int, messages []*schema.Message) bool { if i > 0 && (len(messages) != 4 || messages[2].Content != "new user message") { return false } return true }, messages: []*schema.Message{ schema.AssistantMessage("", []schema.ToolCall{ { ID: "1", Function: schema.FunctionCall{ Name: "tool1", Arguments: "arguments", }, }, }), schema.AssistantMessage("completed", nil), }, }, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{&myTool1{}}, }, }, }) assert.NoError(t, err) runner := NewRunner(ctx, RunnerConfig{ Agent: a, CheckPointStore: newMyStore(), }) iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) event, ok := iter.Next() assert.True(t, ok) event, ok = iter.Next() assert.True(t, ok) assert.NoError(t, event.Err) assert.NotNil(t, event.Action.Interrupted) assert.Equal(t, 1, len(event.Action.Interrupted.InterruptContexts)) assert.Equal(t, Address{ {Type: AddressSegmentAgent, ID: "name"}, {Type: AddressSegmentTool, ID: "tool1", SubID: "1"}, }, event.Action.Interrupted.InterruptContexts[0].Address) var ( chatModelAgentID string toolID string ) intCtx := event.Action.Interrupted.InterruptContexts[0] for intCtx != nil { if intCtx.Address[len(intCtx.Address)-1].Type == AddressSegmentTool { toolID = intCtx.ID } else if intCtx.Address[len(intCtx.Address)-1].Type == AddressSegmentAgent { chatModelAgentID = intCtx.ID } intCtx = intCtx.Parent } event, ok = iter.Next() assert.False(t, ok) iter, err = runner.ResumeWithParams(ctx, "1", &ResumeParams{ Targets: map[string]any{ chatModelAgentID: &ChatModelAgentResumeData{ HistoryModifier: func(ctx context.Context, history []Message) []Message { history[2].Content = "new user message" return history }, }, toolID: "tool resume result", }, }) assert.NoError(t, err) event, ok = iter.Next() assert.True(t, ok) assert.NoError(t, event.Err) assert.Equal(t, event.Output.MessageOutput.Message.Content, "tool resume result") event, ok = iter.Next() assert.True(t, ok) 
assert.NoError(t, event.Err) assert.Equal(t, event.Output.MessageOutput.Message.Content, "completed") } func TestChatModelAgentToolInterrupt(t *testing.T) { sa := &myAgent{ runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() intAct := Interrupt(ctx, "hello world") intAct.Action.Interrupted.Data = "hello world" generator.Send(intAct) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.NotNil(t, info) assert.False(t, info.EnableStreaming) if !info.IsResumeTarget { iter, generator := NewAsyncIteratorPair[*AgentEvent]() intAct := Interrupt(ctx, "interrupt again") intAct.Action.Interrupted.Data = "interrupt again" generator.Send(intAct) generator.Close() return iter } assert.NotNil(t, info.ResumeData) assert.Equal(t, "resume sa", info.ResumeData) iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.UserMessage(fmt.Sprintf("my agent completed with data %s", info.ResumeData))}}}) generator.Close() return iter }, } ctx := context.Background() a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "name", Description: "description", Instruction: "instruction", Model: &myModel{ messages: []*schema.Message{ schema.AssistantMessage("", []schema.ToolCall{ { ID: "1", Function: schema.FunctionCall{ Name: "myAgent", Arguments: "{\"request\":\"123\"}", }, }, }), schema.AssistantMessage("completed", nil), }, }, ToolsConfig: ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{NewAgentTool(ctx, sa)}, }, }, }) assert.NoError(t, err) runner := NewRunner(ctx, RunnerConfig{ Agent: a, CheckPointStore: newMyStore(), }) iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) event, ok := iter.Next() assert.True(t, ok) event, ok = iter.Next() 
assert.True(t, ok) assert.NoError(t, event.Err) assert.NotNil(t, event.Action.Interrupted) event, ok = iter.Next() assert.False(t, ok) iter, err = runner.Resume(ctx, "1") assert.NoError(t, err) event, ok = iter.Next() assert.True(t, ok) assert.NoError(t, event.Err) assert.NotNil(t, event.Action.Interrupted) assert.Equal(t, 1, len(event.Action.Interrupted.InterruptContexts)) for _, ctx := range event.Action.Interrupted.InterruptContexts { if ctx.IsRootCause { assert.Equal(t, Address{ {Type: AddressSegmentAgent, ID: "name"}, {Type: AddressSegmentTool, ID: "myAgent", SubID: "1"}, {Type: AddressSegmentAgent, ID: "myAgent"}, }, ctx.Address) assert.Equal(t, "interrupt again", ctx.Info) } } var toolInterruptID string for _, ctx := range event.Action.Interrupted.InterruptContexts { if ctx.IsRootCause { toolInterruptID = ctx.ID break } } assert.NotEmpty(t, toolInterruptID) event, ok = iter.Next() assert.False(t, ok) iter, err = runner.ResumeWithParams(ctx, "1", &ResumeParams{ Targets: map[string]any{ toolInterruptID: "resume sa", }, }) assert.NoError(t, err) event, ok = iter.Next() assert.True(t, ok) assert.NoError(t, event.Err) assert.Equal(t, event.Output.MessageOutput.Message.Content, "my agent completed with data resume sa") event, ok = iter.Next() assert.True(t, ok) assert.NoError(t, event.Err) assert.Equal(t, event.Output.MessageOutput.Message.Content, "completed") _, ok = iter.Next() assert.False(t, ok) } func newMyStore() *myStore { return &myStore{ m: map[string][]byte{}, } } type myStore struct { m map[string][]byte } func (m *myStore) Set(_ context.Context, key string, value []byte) error { m.m[key] = value return nil } func (m *myStore) Get(_ context.Context, key string) ([]byte, bool, error) { v, ok := m.m[key] return v, ok, nil } type myAgentOptions struct { interrupt bool value string } func withValue(value string) AgentRunOption { return WrapImplSpecificOptFn(func(t *myAgentOptions) { t.value = value }) } type myAgent struct { name string runFn func(ctx 
context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] resumeFn func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] } func (m *myAgent) Name(_ context.Context) string { if len(m.name) > 0 { return m.name } return "myAgent" } func (m *myAgent) Description(_ context.Context) string { return "myAgent description" } func (m *myAgent) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { return m.runFn(ctx, input, options...) } func (m *myAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { return m.resumeFn(ctx, info, opts...) } type myModel struct { times int messages []*schema.Message validator func(int, []*schema.Message) bool } func (m *myModel) Generate(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { if m.validator != nil && !m.validator(m.times, input) { return nil, errors.New("invalid input") } if m.times >= len(m.messages) { return nil, errors.New("exceeded max number of messages") } t := m.times m.times++ return m.messages[t], nil } func (m *myModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { panic("implement me") } func (m *myModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { return m, nil } type myTool1 struct{} func (m *myTool1) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "tool1", Desc: "desc", }, nil } func (m *myTool1) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { if wasInterrupted, _, _ := tool.GetInterruptState[any](ctx); !wasInterrupted { return "", tool.Interrupt(ctx, nil) } if isResumeFlow, hasResumeData, data := tool.GetResumeContext[string](ctx); !isResumeFlow { return "", tool.Interrupt(ctx, nil) } else if hasResumeData { return data, nil } return "result", nil 
} func TestCyclicalAgentInterrupt(t *testing.T) { ctx := context.Background() var agentA, agentB, agentC Agent // agentC interrupts agentC = &myAgent{ name: "C", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() intAct := Interrupt(ctx, "interrupt from C") generator.Send(intAct) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.True(t, info.IsResumeTarget) assert.NotNil(t, info.ResumeData) assert.Equal(t, "resume C", info.ResumeData) iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "C", Output: &AgentOutput{ MessageOutput: &MessageVariant{Message: schema.UserMessage("C completed")}, }, }) generator.Close() return iter }, } // agentB transfers back to its parent A agentB = &myAgent{ name: "B", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "B", Action: &AgentAction{ TransferToAgent: &TransferToAgentAction{ DestAgentName: "A", // Transfer back to parent }, }, }) generator.Close() return iter }, } // agentA is the parent, orchestrating the A->B->A->C flow agentA = &myAgent{ name: "A", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { runCtx := getRunCtx(ctx) iter, generator := NewAsyncIteratorPair[*AgentEvent]() // If the last agent was B, we are in the A->B->A path, so transfer to C. // Otherwise, it's the first run, transfer to B. 
dest := "B" if len(runCtx.RunPath) > 1 && runCtx.RunPath[len(runCtx.RunPath)-2].agentName == "B" { dest = "C" } generator.Send(&AgentEvent{ AgentName: "A", Action: &AgentAction{ TransferToAgent: &TransferToAgentAction{ DestAgentName: dest, }, }, }) generator.Close() return iter }, } // Set up the hierarchy: A is parent of B and C. agentA, err := SetSubAgents(ctx, agentA, []Agent{agentB, agentC}) assert.NoError(t, err) // Run the test runner := NewRunner(ctx, RunnerConfig{ Agent: agentA, CheckPointStore: newMyStore(), }) iter := runner.Query(ctx, "start", WithCheckPointID("cyclical-1")) var events []*AgentEvent for { event, ok := iter.Next() if !ok { break } events = append(events, event) } // We expect 3 transfer events (A->B, B->A, A->C) and 1 interrupt event from C. assert.Equal(t, 4, len(events)) interruptEvent := events[3] assert.NotNil(t, interruptEvent.Action.Interrupted) assert.Equal(t, "C", interruptEvent.AgentName) // Check the interrupt context assert.Equal(t, 1, len(interruptEvent.Action.Interrupted.InterruptContexts)) interruptCtx := interruptEvent.Action.Interrupted.InterruptContexts[0] assert.True(t, interruptCtx.IsRootCause) assert.Equal(t, "interrupt from C", interruptCtx.Info) expectedAddr := Address{ {Type: AddressSegmentAgent, ID: "A"}, {Type: AddressSegmentAgent, ID: "B"}, {Type: AddressSegmentAgent, ID: "A"}, {Type: AddressSegmentAgent, ID: "C"}, } assert.Equal(t, expectedAddr, interruptCtx.Address) assert.NotEmpty(t, interruptCtx.ID) // Resume the execution iter, err = runner.ResumeWithParams(ctx, "cyclical-1", &ResumeParams{ Targets: map[string]any{ interruptCtx.ID: "resume C", }, }) assert.NoError(t, err) events = []*AgentEvent{} for { event, ok := iter.Next() if !ok { break } events = append(events, event) } // We expect one output event from C assert.Equal(t, 1, len(events)) assert.Equal(t, "C completed", events[0].Output.MessageOutput.Message.Content) } // myStatefulTool is a tool that can interrupt and has internal state to track 
// invocations. The state (myStatefulToolState) is not stored on the struct; it
// is attached to each interrupt via tool.StatefulInterrupt and read back with
// tool.GetInterruptState when the call is resumed.
type myStatefulTool struct {
	name string
	t    *testing.T // used for assertions made from inside InvokableRun
}

// Info reports the tool under its configured name.
func (m *myStatefulTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{
		Name: m.name,
		Desc: "desc",
	}, nil
}

// myStatefulToolState is the state attached to each interrupt raised by myStatefulTool.
type myStatefulToolState struct {
	InterruptCount int
}

// init registers myStatefulToolState with schema; presumably required so the
// state can be (de)serialized into checkpoints — TODO confirm.
func init() {
	schema.Register[myStatefulToolState]()
}

// InvokableRun behaves as follows:
//   - On a call with no prior interrupt, it interrupts with InterruptCount 1.
//   - On resume, if this call was targeted with string resume data, it returns
//     that data.
//   - Otherwise it asserts that its previous state is present and interrupts
//     again with InterruptCount incremented.
func (m *myStatefulTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) {
	wasInterrupted, hasState, state := tool.GetInterruptState[myStatefulToolState](ctx)
	if !wasInterrupted {
		return "", tool.StatefulInterrupt(ctx, fmt.Sprintf("interrupt from %s", m.name), myStatefulToolState{InterruptCount: 1})
	}

	isResumeFlow, hasResumeData, data := tool.GetResumeContext[string](ctx)
	if !isResumeFlow || !hasResumeData {
		// Resumed but not targeted: the state from the earlier interrupt must survive.
		assert.True(m.t, hasState, "tool %s should have interrupt state on resume", m.name)
		return "", tool.StatefulInterrupt(ctx, fmt.Sprintf("interrupt from %s", m.name), myStatefulToolState{InterruptCount: state.InterruptCount + 1})
	}

	return data, nil
}

// TestChatModelParallelToolInterruptAndResume checks that two tools called in
// parallel can interrupt together and be resumed one at a time. The non-targeted
// tool re-interrupts while keeping its state.
func TestChatModelParallelToolInterruptAndResume(t *testing.T) {
	ctx := context.Background()

	toolA := &myStatefulTool{name: "toolA", t: t}
	toolB := &myStatefulTool{name: "toolB", t: t}

	chatModel, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
		Name:        "ParallelToolAgent",
		Description: "An agent that uses parallel tools",
		Model: &myModel{
			messages: []*schema.Message{
				// 1. First model response: call toolA and toolB in parallel
				schema.AssistantMessage("", []schema.ToolCall{
					{ID: "1", Function: schema.FunctionCall{Name: "toolA", Arguments: "{}"}},
					{ID: "2", Function: schema.FunctionCall{Name: "toolB", Arguments: "{}"}},
				}),
				// 2. Second model response (after tools are resumed): call them again to check state
				schema.AssistantMessage("", []schema.ToolCall{
					{ID: "3", Function: schema.FunctionCall{Name: "toolA", Arguments: "{}"}},
					{ID: "4", Function: schema.FunctionCall{Name: "toolB", Arguments: "{}"}},
				}),
				// 3. Final completion
				schema.AssistantMessage("all done", nil),
			},
		},
		ToolsConfig: ToolsConfig{
			ToolsNodeConfig: compose.ToolsNodeConfig{
				Tools: []tool.BaseTool{toolA, toolB},
			},
		},
	})
	assert.NoError(t, err)

	runner := NewRunner(ctx, RunnerConfig{
		Agent:           chatModel,
		CheckPointStore: newMyStore(),
	})

	// 1. Initial query -> parallel interrupt from toolA and toolB
	iter := runner.Query(ctx, "start", WithCheckPointID("parallel-tool-test-1"))
	normalEvents, interruptEvent := consumeUntilInterrupt(iter)
	assert.Equal(t, 1, len(normalEvents))
	assert.NotNil(t, interruptEvent)
	assert.Equal(t, 2, len(interruptEvent.Action.Interrupted.InterruptContexts), "should have 2 interrupts")

	var toolAInterruptID, toolBInterruptID string
	for _, info := range interruptEvent.Action.Interrupted.InterruptContexts {
		if info.Info == "interrupt from toolA" {
			toolAInterruptID = info.ID
			assert.True(t, info.IsRootCause)
		} else if info.Info == "interrupt from toolB" {
			toolBInterruptID = info.ID
			assert.True(t, info.IsRootCause)
		}
	}
	assert.NotEmpty(t, toolAInterruptID)
	assert.NotEmpty(t, toolBInterruptID)

	// 2. Resume, targeting only toolA. toolB should re-interrupt.
	iter, err = runner.ResumeWithParams(ctx, "parallel-tool-test-1", &ResumeParams{
		Targets: map[string]any{
			toolAInterruptID: "toolA resumed",
		},
	})
	assert.NoError(t, err)
	_, interruptEvent = consumeUntilInterrupt(iter)
	assert.NotNil(t, interruptEvent, "expected a re-interrupt from toolB")
	assert.Equal(t, 1, len(interruptEvent.Action.Interrupted.InterruptContexts), "should have 1 remaining interrupts")

	var rootCause *InterruptCtx
	for _, info := range interruptEvent.Action.Interrupted.InterruptContexts {
		if info.IsRootCause {
			rootCause = info
			break
		}
	}
	if rootCause == nil {
		t.Fatal("expected a root cause interrupt from toolB")
	}
	assert.Equal(t, "interrupt from toolB", rootCause.Info)
	toolBReInterruptID := rootCause.ID

	// 3. Resume the re-interrupted toolB. The agent should then call the tools again.
	iter, err = runner.ResumeWithParams(ctx, "parallel-tool-test-1", &ResumeParams{
		Targets: map[string]any{
			toolBReInterruptID: "toolB resumed",
		},
	})
	assert.NoError(t, err)

	// 4. Consume all remaining events. The tools' internal assertions check their
	// interrupt state on resume.
	// The model's second response issues new calls (IDs 3 and 4), and
	// myStatefulTool interrupts whenever GetInterruptState reports no prior
	// interrupt. So the run is expected to end in a fresh interrupt instead of
	// reaching "all done"; presumably this is because interrupt state is tracked
	// per tool-call ID (TODO confirm).
	// Which two non-interrupt events precede that interrupt is not pinned down
	// here; the test only checks the count.
	finalEvents, interruptEvent := consumeUntilInterrupt(iter)
	assert.Equal(t, 2, len(finalEvents))
	assert.NotNil(t, interruptEvent)
}

// TestNestedChatModelAgentWithAgentTool verifies that the shouldFire method correctly prevents
// duplicate event firing in nested ChatModelAgent scenarios (ChatModelAgent -> AgentTool -> ChatModelAgent).
// This ensures that only the inner agent's cbHandler fires, not the outer agent's.
func TestNestedChatModelAgentWithAgentTool(t *testing.T) {
	ctx := context.Background()

	// Create an interruptible tool for the inner agent
	innerTool := &myStatefulTool{name: "innerTool", t: t}

	// Create the inner ChatModelAgent that will be wrapped by AgentTool
	innerAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
		Name:        "InnerAgent",
		Description: "Inner agent with interruptible tool",
		Model: &myModel{
			messages: []*schema.Message{
				schema.AssistantMessage("", []schema.ToolCall{
					{ID: "1", Function: schema.FunctionCall{Name: "innerTool", Arguments: "{}"}},
				}),
				schema.AssistantMessage("inner agent completed", nil),
			},
		},
		ToolsConfig: ToolsConfig{
			ToolsNodeConfig: compose.ToolsNodeConfig{
				Tools: []tool.BaseTool{innerTool},
			},
		},
	})
	assert.NoError(t, err)

	// Wrap the inner agent in an AgentTool
	agentTool := NewAgentTool(ctx, innerAgent)

	// Create the outer ChatModelAgent that uses the AgentTool
	outerAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
		Name:        "OuterAgent",
		Description: "Outer agent with AgentTool containing inner agent",
		Model: &myModel{
			messages: []*schema.Message{
				schema.AssistantMessage("", []schema.ToolCall{
					{ID: "1", Function: schema.FunctionCall{Name: "InnerAgent", Arguments: "{}"}},
				}),
				schema.AssistantMessage("outer agent completed", nil),
			},
		},
		ToolsConfig: ToolsConfig{
			ToolsNodeConfig: compose.ToolsNodeConfig{
				Tools: []tool.BaseTool{agentTool},
			},
		},
	})
	assert.NoError(t, err)

	runner := NewRunner(ctx, RunnerConfig{
		Agent:           outerAgent,
		CheckPointStore: newMyStore(),
	})

	// Run the query - this should trigger the nested agent structure
	iter := runner.Query(ctx, "start", WithCheckPointID("nested-agent-test-1"))

	// Collect all events to verify no duplicates
	var allEvents []*AgentEvent
	var interruptEvent *AgentEvent
	for {
		event, ok := iter.Next()
		if !ok {
			break
		}
		if event.Action != nil && event.Action.Interrupted != nil {
			// A second interrupt event would mean the interrupt was fired twice.
			assert.Nil(t, interruptEvent)
			interruptEvent = event
		}
		allEvents = append(allEvents, event)
	}
	if interruptEvent == nil {
		t.Fatal("expected an interrupt event")
	}

	// Verify we got exactly one interrupt event (not duplicated)
	// NOTE(review): this NotNil is redundant after the t.Fatal guard above.
	assert.NotNil(t, interruptEvent, "should have an interrupt event")
	assert.Equal(t, 1, len(interruptEvent.Action.Interrupted.InterruptContexts), "should have exactly one interrupt context")

	// Verify the interrupt comes from the inner tool, not duplicated
	interruptCtx := interruptEvent.Action.Interrupted.InterruptContexts[0]
	assert.True(t, interruptCtx.IsRootCause, "interrupt should be root cause")
	assert.Equal(t, "interrupt from innerTool", interruptCtx.Info)

	// Verify the address path shows the correct nested structure
	expectedAddress := Address{
		{Type: AddressSegmentAgent, ID: "OuterAgent"},
		{Type: AddressSegmentTool, ID: "InnerAgent", SubID: "1"},
		{Type: AddressSegmentAgent, ID: "InnerAgent"},
		{Type: AddressSegmentTool, ID: "innerTool", SubID: "1"},
	}
	assert.Equal(t, expectedAddress, interruptCtx.Address, "interrupt address should show correct nested structure")

	// Verify no duplicate events by checking agent names in events
	var agentNames []string
	for _, event := range allEvents {
		if event.AgentName != "" {
			agentNames = append(agentNames, event.AgentName)
		}
	}

	// Should only have events from the outer agent (the inner agent's events should be handled
	// by the AgentTool and not duplicated by the outer agent's cbHandler)
	for _, name := range agentNames {
		assert.Equal(t, "OuterAgent", name, "all events should come from OuterAgent, not duplicated from InnerAgent")
	}

	// Now resume the interrupt
	interruptID := interruptCtx.ID
	iter, err = runner.ResumeWithParams(ctx, "nested-agent-test-1", &ResumeParams{
		Targets: map[string]any{
			interruptID: "resume inner tool",
		},
	})
	assert.NoError(t, err)

	// Collect final events after resume
	var finalEvents []*AgentEvent
	for {
		event, ok := iter.Next()
		if !ok {
			break
		}
		finalEvents = append(finalEvents, event)
	}

	// Verify completion events
	assert.Greater(t, len(finalEvents), 0, "should have completion events after resume")

	// Check that we get the expected completion messages
	var foundInnerCompletion, foundOuterCompletion bool
	for _, event := range finalEvents {
		if event.Output != nil && event.Output.MessageOutput != nil {
			if event.Output.MessageOutput.Message != nil {
				content := event.Output.MessageOutput.Message.Content
				if content == "inner agent completed" {
					foundInnerCompletion = true
				} else if content == "outer agent completed" {
					foundOuterCompletion = true
				}
			}
		}
	}
	assert.True(t, foundInnerCompletion, "should have inner agent completion")
	assert.True(t, foundOuterCompletion, "should have outer agent completion")
}

// consumeUntilInterrupt drains iter completely and splits what it yields into
// non-interrupt events (in order) and the interrupt event. It does not stop at
// the first interrupt: iteration continues until the iterator is exhausted, and
// if several interrupt events arrive only the last one is returned.
func consumeUntilInterrupt(iter *AsyncIterator[*AgentEvent]) (normalEvents []*AgentEvent, interruptEvent *AgentEvent) {
	for {
		event, ok := iter.Next()
		if !ok {
			break
		}
		if event.Action != nil && event.Action.Interrupted != nil {
			// Keep draining; a later interrupt event overwrites an earlier one.
			interruptEvent = event
			continue
		}
		normalEvents = append(normalEvents, event)
	}
	return
}

// returnDirectlyTool always succeeds with a fixed result. The test below lists it
// in ToolsConfig.ReturnDirectly.
type returnDirectlyTool struct {
	name string
}

func (t *returnDirectlyTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{
		Name: t.name,
		Desc: "A tool that returns directly",
	}, nil
}

func (t *returnDirectlyTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) {
	return "return directly result", nil
}

// interruptingTool interrupts on its first run. On resume it returns the string
// resume data if it was provided, and "resumed without data" otherwise.
type interruptingTool struct {
	name string
}

func (i *interruptingTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{
		Name: i.name,
		Desc: "A tool that interrupts",
	}, nil
}

func (i *interruptingTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) {
	if wasInterrupted, _, _ := compose.GetInterruptState[any](ctx); !wasInterrupted {
		return "", compose.Interrupt(ctx, "interrupt data")
	}

	if isResumeFlow, hasResumeData, data := compose.GetResumeContext[string](ctx); isResumeFlow && hasResumeData {
		return data, nil
	}

	return "resumed without data", nil
}

// twoToolCallModel is a scripted chat model:
//   - Its first Generate call returns two parallel tool calls (return-directly
//     and interrupting).
//   - Every later call returns "final response".
//
// It records the last non-nil tool list passed via options; mu guards callCount
// and receivedTools.
type twoToolCallModel struct {
	returnDirectlyToolName string
	interruptingToolName   string
	callCount              int
	receivedTools          []*schema.ToolInfo
	mu                     sync.Mutex
}

func (m *twoToolCallModel) Generate(_ context.Context, _ []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	m.mu.Lock()
	m.callCount++
	callNum := m.callCount
	options := model.GetCommonOptions(&model.Options{}, opts...)
	// Only overwritten when tools are supplied; never cleared between calls.
	if options.Tools != nil {
		m.receivedTools = options.Tools
	}
	m.mu.Unlock()

	if callNum == 1 {
		return &schema.Message{
			Role:    schema.Assistant,
			Content: "",
			ToolCalls: []schema.ToolCall{
				{
					ID:   "call_return_directly",
					Type: "function",
					Function: schema.FunctionCall{
						Name:      m.returnDirectlyToolName,
						Arguments: "{}",
					},
				},
				{
					ID:   "call_interrupting",
					Type: "function",
					Function: schema.FunctionCall{
						Name:      m.interruptingToolName,
						Arguments: "{}",
					},
				},
			},
		}, nil
	}

	return schema.AssistantMessage("final response", nil), nil
}

func (m *twoToolCallModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	panic("not implemented")
}

// WithTools ignores the tools and returns the same model instance; tool lists are
// observed through Generate's options instead.
func (m *twoToolCallModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
	return m, nil
}

// GetReceivedTools returns the most recently recorded tool list under the lock.
func (m *twoToolCallModel) GetReceivedTools() []*schema.ToolInfo {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.receivedTools
}

// dynamicTool is a trivial tool handed to interruptTestToolsHandler. Presumably
// that handler adds it at run time — TODO confirm against its definition.
type dynamicTool struct {
	name string
}

func (t *dynamicTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{
		Name: t.name,
		Desc: "A dynamically added tool",
	}, nil
}

func (t *dynamicTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) {
	return "dynamic tool result", nil
}

// TestReturnDirectlyEventSentAfterResume runs a model turn that calls a
// ReturnDirectly tool and an interrupting tool in parallel. After the interrupt
// is resumed, it expects the return-directly tool's output event to be emitted.
func TestReturnDirectlyEventSentAfterResume(t *testing.T) {
	ctx := context.Background()

	returnDirectlyToolName := "return_directly_tool"
	interruptingToolName := "interrupting_tool"
	dynamicToolName := "dynamic_tool"

	mdl := &twoToolCallModel{
		returnDirectlyToolName: returnDirectlyToolName,
		interruptingToolName:   interruptingToolName,
	}

	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
		Name:        "TestAgent",
		Description: "Test agent for return directly + interrupt",
		Model:       mdl,
		ToolsConfig: ToolsConfig{
			ToolsNodeConfig: compose.ToolsNodeConfig{
				Tools: []tool.BaseTool{
					&returnDirectlyTool{name: returnDirectlyToolName},
					&interruptingTool{name: interruptingToolName},
				},
			},
			ReturnDirectly: map[string]bool{
				returnDirectlyToolName: true,
			},
		},
		Handlers: []ChatModelAgentMiddleware{
			&interruptTestToolsHandler{tools: []tool.BaseTool{&dynamicTool{name: dynamicToolName}}},
		},
	})
	assert.NoError(t, err)

	store := newMyStore()
	runner := NewRunner(ctx, RunnerConfig{
		Agent:           agent,
		EnableStreaming: false,
		CheckPointStore: store,
	})

	iter := runner.Query(ctx, "test input", WithCheckPointID("test_checkpoint"))

	var interruptEvent *AgentEvent
	for {
		event, ok := iter.Next()
		if !ok {
			break
		}
		if event.Action != nil && event.Action.Interrupted != nil {
			interruptEvent = event
		}
	}

	// assert.NotNil does not stop the test; a nil interruptEvent panics on the next line.
	assert.NotNil(t, interruptEvent, "Should have an interrupt event")
	assert.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts)

	receivedToolsBeforeResume := mdl.GetReceivedTools()
	var hasDynamicToolBeforeResume bool
	for _, ti := range receivedToolsBeforeResume {
		if ti.Name == dynamicToolName {
			hasDynamicToolBeforeResume = true
		}
	}
	assert.True(t, hasDynamicToolBeforeResume, "Dynamic tool should be in tool list before interrupt")

	interruptID := interruptEvent.Action.Interrupted.InterruptContexts[0].ID

	resumeIter, err := runner.ResumeWithParams(ctx, "test_checkpoint", &ResumeParams{
		Targets: map[string]any{
			interruptID: "resume data",
		},
	})
	assert.NoError(t, err)

	var resumeEvents []*AgentEvent
	for {
		event, ok := resumeIter.Next()
		if !ok {
			break
		}
		resumeEvents = append(resumeEvents, event)
	}

	var hasReturnDirectlyEvent bool
	for _, e := range resumeEvents {
		if e.Output != nil && e.Output.MessageOutput != nil {
			if e.Output.MessageOutput.Role == schema.Tool && e.Output.MessageOutput.ToolName == returnDirectlyToolName {
				hasReturnDirectlyEvent = true
			}
		}
	}
	assert.True(t, hasReturnDirectlyEvent, "ReturnDirectlyEvent should be sent after resume")

	// NOTE(review): twoToolCallModel never clears receivedTools. If Generate is not
	// called again after resume, this check sees the pre-interrupt list and passes
	// trivially. Consider resetting mdl.receivedTools before resuming, or also
	// asserting on callCount — TODO confirm whether the model runs after resume.
	receivedToolsAfterResume := mdl.GetReceivedTools()
	var hasDynamicToolAfterResume bool
	for _, ti := range receivedToolsAfterResume {
		if ti.Name == dynamicToolName {
			hasDynamicToolAfterResume = true
		}
	}
	assert.True(t, hasDynamicToolAfterResume, "Dynamic tool should be in tool list after resume (bc.toolUpdated path)")
}

================================================
FILE: adk/middlewares/dynamictool/toolsearch/toolsearch.go
================================================
/*
 * Copyright 2026 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Package toolsearch provides tool search middleware: a ChatModelAgent
// middleware that hides a set of dynamic tools from the model until the model
// discovers them through a "tool_search" meta-tool.
package toolsearch

import (
	"context"
	"encoding/json"
	"fmt"
	"regexp"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/schema"
)

// Config is the configuration for the tool search middleware.
type Config struct {
	// DynamicTools is a list of tools that can be dynamically searched and loaded by the agent.
	// It must be non-empty. Tools are matched and filtered by the Name reported by Info.
	DynamicTools []tool.BaseTool
}

// New constructs and returns the tool search middleware.
//
// The tool search middleware enables dynamic tool selection for agents with large tool libraries.
// Instead of passing all tools to the model at once (which can overwhelm context limits),
// this middleware:
//
//  1. Adds a "tool_search" meta-tool that accepts a regex pattern to search tool names
//  2. Initially hides all dynamic tools from the model's tool list
//  3.
When the model calls tool_search, matching tools become available for subsequent calls // // Example usage: // // middleware, _ := toolsearch.New(ctx, &toolsearch.Config{ // DynamicTools: []tool.BaseTool{weatherTool, stockTool, currencyTool, ...}, // }) // agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ // // ... // Handlers: []adk.ChatModelAgentMiddleware{middleware}, // }) func New(ctx context.Context, config *Config) (adk.ChatModelAgentMiddleware, error) { if config == nil { return nil, fmt.Errorf("config is required") } if len(config.DynamicTools) == 0 { return nil, fmt.Errorf("tools is required") } return &middleware{ dynamicTools: config.DynamicTools, }, nil } type middleware struct { adk.BaseChatModelAgentMiddleware dynamicTools []tool.BaseTool } func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { if runCtx == nil { return ctx, runCtx, nil } nRunCtx := *runCtx toolNames, err := getToolNames(ctx, m.dynamicTools) if err != nil { return ctx, nil, fmt.Errorf("failed to get tool names: %w", err) } nRunCtx.Tools = append(nRunCtx.Tools, newToolSearchTool(toolNames)) nRunCtx.Tools = append(nRunCtx.Tools, m.dynamicTools...) return ctx, &nRunCtx, nil } func (m *middleware) WrapModel(_ context.Context, cm model.BaseChatModel, mc *adk.ModelContext) (model.BaseChatModel, error) { return &wrapper{allTools: mc.Tools, cm: cm, dynamicTools: m.dynamicTools}, nil } type wrapper struct { allTools []*schema.ToolInfo dynamicTools []tool.BaseTool cm model.BaseChatModel } func (w *wrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { tools, err := removeTools(ctx, w.allTools, w.dynamicTools, input) if err != nil { return nil, fmt.Errorf("failed to load dynamic tools: %w", err) } return w.cm.Generate(ctx, input, append(opts, model.WithTools(tools))...) 
} func (w *wrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { tools, err := removeTools(ctx, w.allTools, w.dynamicTools, input) if err != nil { return nil, fmt.Errorf("failed to load dynamic tools: %w", err) } return w.cm.Stream(ctx, input, append(opts, model.WithTools(tools))...) } func newToolSearchTool(toolNames []string) *toolSearchTool { return &toolSearchTool{toolNames: toolNames} } type toolSearchTool struct { toolNames []string } const ( toolSearchToolName = "tool_search" ) func (t *toolSearchTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "tool_search", Desc: "Search for tools using a regex pattern that matches tool names. Returns a list of matching tool names. Use this when you need a tool but don't have it available yet.", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "regex_pattern": { Type: schema.String, Desc: "A regex pattern to match tool names against.", Required: true, }, }), }, nil } type toolSearchArgs struct { RegexPattern string `json:"regex_pattern"` } type toolSearchResult struct { SelectedTools []string `json:"selectedTools"` } func (t *toolSearchTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { var args toolSearchArgs if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil { return "", fmt.Errorf("failed to unmarshal tool search arguments: %w", err) } if args.RegexPattern == "" { return "", fmt.Errorf("regex_pattern is required") } re, err := regexp.Compile(args.RegexPattern) if err != nil { return "", fmt.Errorf("invalid regex pattern: %w", err) } var matchedTools []string for _, name := range t.toolNames { if re.MatchString(name) { matchedTools = append(matchedTools, name) } } result := toolSearchResult{ SelectedTools: matchedTools, } output, err := json.Marshal(result) if err != nil { return "", fmt.Errorf("failed to 
marshal result: %w", err) } return string(output), nil } func getToolNames(ctx context.Context, tools []tool.BaseTool) ([]string, error) { ret := make([]string, 0, len(tools)) for _, t := range tools { info, err := t.Info(ctx) if err != nil { return nil, err } ret = append(ret, info.Name) } return ret, nil } func extractSelectedTools(ctx context.Context, messages []*schema.Message) ([]string, error) { var selectedTools []string for _, message := range messages { if message.Role != schema.Tool || message.ToolName != toolSearchToolName { continue } result := &toolSearchResult{} err := json.Unmarshal([]byte(message.Content), result) if err != nil { return nil, fmt.Errorf("failed to unmarshal tool search tool result: %w", err) } selectedTools = append(selectedTools, result.SelectedTools...) } return selectedTools, nil } func invertSelect[T comparable](all []T, selected []T) map[T]struct{} { selectedSet := make(map[T]struct{}, len(selected)) for _, s := range selected { selectedSet[s] = struct{}{} } result := make(map[T]struct{}) for _, item := range all { if _, ok := selectedSet[item]; !ok { result[item] = struct{}{} } } return result } func removeTools(ctx context.Context, all []*schema.ToolInfo, dynamicTools []tool.BaseTool, messages []*schema.Message) ([]*schema.ToolInfo, error) { selectedToolNames, err := extractSelectedTools(ctx, messages) if err != nil { return nil, err } dynamicToolNames, err := getToolNames(ctx, dynamicTools) if err != nil { return nil, err } removeMap := invertSelect(dynamicToolNames, selectedToolNames) ret := make([]*schema.ToolInfo, 0, len(all)-len(dynamicTools)) for _, info := range all { if _, ok := removeMap[info.Name]; ok { continue } ret = append(ret, info) } return ret, nil } ================================================ FILE: adk/middlewares/dynamictool/toolsearch/toolsearch_test.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the 
"License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package toolsearch

import (
	"context"
	"encoding/json"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/schema"
)

// mockTool is a minimal tool.BaseTool whose Info reports a fixed name and description.
type mockTool struct {
	name string
	desc string
}

func (m *mockTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{
		Name: m.name,
		Desc: m.desc,
	}, nil
}

func newMockTool(name, desc string) *mockTool {
	return &mockTool{name: name, desc: desc}
}

// TestNew covers config validation in New.
func TestNew(t *testing.T) {
	ctx := context.Background()

	t.Run("nil config returns error", func(t *testing.T) {
		m, err := New(ctx, nil)
		assert.Nil(t, m)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "config is required")
	})

	t.Run("empty tools returns error", func(t *testing.T) {
		m, err := New(ctx, &Config{DynamicTools: []tool.BaseTool{}})
		assert.Nil(t, m)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "tools is required")
	})

	t.Run("valid config returns middleware", func(t *testing.T) {
		tools := []tool.BaseTool{
			newMockTool("tool1", "desc1"),
			newMockTool("tool2", "desc2"),
		}
		m, err := New(ctx, &Config{DynamicTools: tools})
		assert.NoError(t, err)
		assert.NotNil(t, m)
	})
}

// TestMiddleware_BeforeAgent checks two things: a nil run context passes
// through, and a non-nil one gains tool_search plus every dynamic tool.
func TestMiddleware_BeforeAgent(t *testing.T) {
	ctx := context.Background()

	t.Run("nil runCtx returns nil", func(t *testing.T) {
		tools := []tool.BaseTool{newMockTool("tool1", "desc1")}
		m, err := New(ctx, &Config{DynamicTools: tools})
		require.NoError(t, err)

		newCtx, newRunCtx, err := m.BeforeAgent(ctx, nil)
		assert.NoError(t, err)
		assert.Equal(t, ctx, newCtx)
		assert.Nil(t, newRunCtx)
	})

	t.Run("adds tool_search and dynamic tools", func(t *testing.T) {
		tools := []tool.BaseTool{
			newMockTool("tool1", "desc1"),
			newMockTool("tool2", "desc2"),
		}
		m, err := New(ctx, &Config{DynamicTools: tools})
		require.NoError(t, err)
		// The local variable shadows the middleware type name from here on; the
		// type assertion on the right-hand side is resolved before that.
		middleware := m.(*middleware)

		runCtx := &adk.ChatModelAgentContext{
			Tools: []tool.BaseTool{},
		}

		_, newRunCtx, err := middleware.BeforeAgent(ctx, runCtx)
		assert.NoError(t, err)
		assert.NotNil(t, newRunCtx)
		// tool_search + 2 dynamic tools
		assert.Len(t, newRunCtx.Tools, 3)
	})
}

// TestToolSearchTool_Info checks the advertised name, description and parameters.
func TestToolSearchTool_Info(t *testing.T) {
	ctx := context.Background()
	toolNames := []string{"tool1", "tool2", "tool3"}
	tst := newToolSearchTool(toolNames)

	info, err := tst.Info(ctx)
	assert.NoError(t, err)
	assert.Equal(t, "tool_search", info.Name)
	assert.Contains(t, info.Desc, "regex pattern")
	assert.NotNil(t, info.ParamsOneOf)
}

// TestToolSearchTool_InvokableRun covers argument validation and regex matching
// against a fixed set of tool names.
func TestToolSearchTool_InvokableRun(t *testing.T) {
	ctx := context.Background()
	toolNames := []string{"get_weather", "get_time", "search_web", "calculate_sum"}
	tst := newToolSearchTool(toolNames)

	t.Run("empty regex pattern returns error", func(t *testing.T) {
		args := `{"regex_pattern": ""}`
		result, err := tst.InvokableRun(ctx, args)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "regex_pattern is required")
		assert.Empty(t, result)
	})

	t.Run("invalid json returns error", func(t *testing.T) {
		args := `{invalid json}`
		result, err := tst.InvokableRun(ctx, args)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to unmarshal")
		assert.Empty(t, result)
	})

	t.Run("invalid regex returns error", func(t *testing.T) {
		args := `{"regex_pattern": "[invalid"}`
		result, err := tst.InvokableRun(ctx, args)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid regex pattern")
		assert.Empty(t, result)
	})

	t.Run("matches tools with prefix pattern", func(t *testing.T) {
		args := `{"regex_pattern": "^get_"}`
		result, err := tst.InvokableRun(ctx, args)
		assert.NoError(t, err)

		var res toolSearchResult
		err = json.Unmarshal([]byte(result), &res)
		assert.NoError(t, err)
		assert.ElementsMatch(t, []string{"get_weather", "get_time"}, res.SelectedTools)
	})

	t.Run("matches tools with suffix pattern", func(t *testing.T) {
		args := `{"regex_pattern": "_sum$"}`
		result, err := tst.InvokableRun(ctx, args)
		assert.NoError(t, err)

		var res toolSearchResult
		err = json.Unmarshal([]byte(result), &res)
		assert.NoError(t, err)
		assert.ElementsMatch(t, []string{"calculate_sum"}, res.SelectedTools)
	})

	t.Run("matches all tools with wildcard", func(t *testing.T) {
		args := `{"regex_pattern": ".*"}`
		result, err := tst.InvokableRun(ctx, args)
		assert.NoError(t, err)

		var res toolSearchResult
		err = json.Unmarshal([]byte(result), &res)
		assert.NoError(t, err)
		assert.ElementsMatch(t, toolNames, res.SelectedTools)
	})

	t.Run("no matches returns empty list", func(t *testing.T) {
		args := `{"regex_pattern": "^nonexistent_"}`
		result, err := tst.InvokableRun(ctx, args)
		assert.NoError(t, err)

		var res toolSearchResult
		err = json.Unmarshal([]byte(result), &res)
		assert.NoError(t, err)
		// assert.Empty accepts both a nil and an empty SelectedTools.
		assert.Empty(t, res.SelectedTools)
	})
}

// TestGetToolNames checks that names are collected in input order.
func TestGetToolNames(t *testing.T) {
	ctx := context.Background()

	t.Run("returns tool names", func(t *testing.T) {
		tools := []tool.BaseTool{
			newMockTool("tool1", "desc1"),
			newMockTool("tool2", "desc2"),
			newMockTool("tool3", "desc3"),
		}
		names, err := getToolNames(ctx, tools)
		assert.NoError(t, err)
		assert.Equal(t, []string{"tool1", "tool2", "tool3"}, names)
	})

	t.Run("empty tools returns empty slice", func(t *testing.T) {
		names, err := getToolNames(ctx, []tool.BaseTool{})
		assert.NoError(t, err)
		assert.Empty(t, names)
	})
}

// TestExtractSelectedTools checks how tool_search results are aggregated from
// the history, and that other messages are ignored.
func TestExtractSelectedTools(t *testing.T) {
	ctx := context.Background()

	t.Run("extracts selected tools from messages", func(t *testing.T) {
		result := toolSearchResult{SelectedTools: []string{"tool1", "tool2"}}
		resultJSON, _ := json.Marshal(result)

		messages := []*schema.Message{
			schema.UserMessage("hello"),
			{Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)},
		}

		selected, err := extractSelectedTools(ctx, messages)
		assert.NoError(t, err)
		assert.ElementsMatch(t, []string{"tool1", "tool2"}, selected)
	})

	t.Run("handles multiple tool_search results", func(t *testing.T) {
		result1 := toolSearchResult{SelectedTools: []string{"tool1"}}
		result1JSON, _ := json.Marshal(result1)
		result2 := toolSearchResult{SelectedTools: []string{"tool2", "tool3"}}
		result2JSON, _ := json.Marshal(result2)

		messages := []*schema.Message{
			{Role: schema.Tool, ToolName: toolSearchToolName, Content: string(result1JSON)},
			schema.UserMessage("continue"),
			{Role: schema.Tool, ToolName: toolSearchToolName, Content: string(result2JSON)},
		}

		selected, err := extractSelectedTools(ctx, messages)
		assert.NoError(t, err)
		assert.ElementsMatch(t, []string{"tool1", "tool2", "tool3"}, selected)
	})

	t.Run("ignores non-tool_search messages", func(t *testing.T) {
		messages := []*schema.Message{
			schema.UserMessage("hello"),
			{Role: schema.Tool, ToolName: "other_tool", Content: "some content"},
			{Role: schema.Assistant, Content: "response"},
		}

		selected, err := extractSelectedTools(ctx, messages)
		assert.NoError(t, err)
		assert.Empty(t, selected)
	})

	t.Run("returns error for invalid json", func(t *testing.T) {
		messages := []*schema.Message{
			{Role: schema.Tool, ToolName: toolSearchToolName, Content: "invalid json"},
		}

		selected, err := extractSelectedTools(ctx, messages)
		assert.Error(t, err)
		assert.Nil(t, selected)
	})
}

// TestInvertSelect checks set difference for string and int element types.
func TestInvertSelect(t *testing.T) {
	t.Run("returns items not in selected", func(t *testing.T) {
		all := []string{"a", "b", "c", "d"}
		selected := []string{"b", "d"}
		result := invertSelect(all, selected)
		assert.Len(t, result, 2)
		_, hasA := result["a"]
		_, hasC := result["c"]
		assert.True(t, hasA)
		assert.True(t, hasC)
	})

	t.Run("empty selected returns all", func(t *testing.T) {
		all := []string{"a", "b", "c"}
		selected := []string{}
		result := invertSelect(all, selected)
		assert.Len(t, result, 3)
	})

	t.Run("all selected returns empty", func(t *testing.T) {
		all := []string{"a", "b"}
		selected := []string{"a", "b"}
		result := invertSelect(all, selected)
		assert.Empty(t, result)
	})

	t.Run("works with integers", func(t *testing.T) {
		all := []int{1, 2, 3, 4, 5}
		selected := []int{2, 4}
		result := invertSelect(all, selected)
		assert.Len(t, result, 3)
		_, has1 := result[1]
		_, has3 := result[3]
		_, has5 := result[5]
		assert.True(t, has1)
		assert.True(t, has3)
		assert.True(t, has5)
	})
}

// TestRemoveTools checks that unselected dynamic tools are filtered out and that
// static tools are always kept.
func TestRemoveTools(t *testing.T) {
	ctx := context.Background()

	t.Run("removes unselected dynamic tools", func(t *testing.T) {
		allTools := []*schema.ToolInfo{
			{Name: "static_tool"},
			{Name: "dynamic_tool1"},
			{Name: "dynamic_tool2"},
			{Name: "dynamic_tool3"},
		}
		dynamicTools := []tool.BaseTool{
			newMockTool("dynamic_tool1", ""),
			newMockTool("dynamic_tool2", ""),
			newMockTool("dynamic_tool3", ""),
		}

		result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}}
		resultJSON, _ := json.Marshal(result)
		messages := []*schema.Message{
			{Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)},
		}

		tools, err := removeTools(ctx, allTools, dynamicTools, messages)
		assert.NoError(t, err)
		assert.Len(t, tools, 2)

		toolNames := make([]string, len(tools))
		// NOTE(review): the loop variable t shadows *testing.T inside this loop only;
		// the assertion after the loop still uses the subtest's t.
		for i, t := range tools {
			toolNames[i] = t.Name
		}
		assert.ElementsMatch(t, []string{"static_tool", "dynamic_tool1"}, toolNames)
	})

	t.Run("remove all dynamic tools when no tool_search result", func(t *testing.T) {
		allTools := []*schema.ToolInfo{
			{Name: "static_tool"},
			{Name: "dynamic_tool1"},
		}
		dynamicTools := []tool.BaseTool{
			newMockTool("dynamic_tool1", ""),
		}
		messages := []*schema.Message{
			schema.UserMessage("hello"),
		}

		tools, err := removeTools(ctx, allTools, dynamicTools, messages)
		assert.NoError(t, err)
		assert.Len(t, tools, 1)
		assert.Equal(t, "static_tool", tools[0].Name)
	})

	t.Run("handles empty dynamic tools", func(t *testing.T) {
		allTools := []*schema.ToolInfo{
			{Name: "static_tool1"},
			{Name: "static_tool2"},
		}
		dynamicTools := []tool.BaseTool{}
		messages := []*schema.Message{}

		tools, err := removeTools(ctx, allTools, dynamicTools, messages)
		assert.NoError(t, err)
		assert.Len(t, tools, 2)
	})
}

// mockChatModel delegates to generateFunc/streamFunc when they are set.
// Otherwise Generate returns a fixed assistant message and Stream returns (nil, nil).
type mockChatModel struct {
	generateFunc func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error)
	streamFunc   func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error)
}

func (m *mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	if m.generateFunc != nil {
		return m.generateFunc(ctx, input, opts...)
	}
	return &schema.Message{Role: schema.Assistant, Content: "response"}, nil
}

func (m *mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	if m.streamFunc != nil {
		return m.streamFunc(ctx, input, opts...)
	}
	return nil, nil
}

// TestWrapper_Generate checks that the tool option passed to the wrapped model
// contains only the static tool and the selected dynamic tool, in their original order.
func TestWrapper_Generate(t *testing.T) {
	ctx := context.Background()

	t.Run("filters tools based on tool_search result", func(t *testing.T) {
		allTools := []*schema.ToolInfo{
			{Name: "static_tool"},
			{Name: "dynamic_tool1"},
			{Name: "dynamic_tool2"},
		}
		dynamicTools := []tool.BaseTool{
			newMockTool("dynamic_tool1", ""),
			newMockTool("dynamic_tool2", ""),
		}

		result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}}
		resultJSON, _ := json.Marshal(result)
		messages := []*schema.Message{
			schema.UserMessage("hello"),
			{Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)},
		}

		w := &wrapper{
			allTools:     allTools,
			dynamicTools: dynamicTools,
			cm: &mockChatModel{
				generateFunc: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
					options := model.GetCommonOptions(nil, opts...)
					assert.Len(t, options.Tools, 2)
					assert.Equal(t, "static_tool", options.Tools[0].Name)
					assert.Equal(t, "dynamic_tool1", options.Tools[1].Name)
					return nil, nil
				},
			},
		}

		_, err := w.Generate(ctx, messages)
		assert.NoError(t, err)
	})
}

// TestWrapper_Stream is the streaming counterpart of TestWrapper_Generate.
func TestWrapper_Stream(t *testing.T) {
	ctx := context.Background()

	t.Run("filters tools based on tool_search result", func(t *testing.T) {
		allTools := []*schema.ToolInfo{
			{Name: "static_tool"},
			{Name: "dynamic_tool1"},
			{Name: "dynamic_tool2"},
		}
		dynamicTools := []tool.BaseTool{
			newMockTool("dynamic_tool1", ""),
			newMockTool("dynamic_tool2", ""),
		}

		result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}}
		resultJSON, _ := json.Marshal(result)
		messages := []*schema.Message{
			schema.UserMessage("hello"),
			{Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)},
		}

		w := &wrapper{
			allTools:     allTools,
			dynamicTools: dynamicTools,
			cm: &mockChatModel{
				streamFunc: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
					options := model.GetCommonOptions(nil, opts...)
					assert.Len(t, options.Tools, 2)
					assert.Equal(t, "static_tool", options.Tools[0].Name)
					assert.Equal(t, "dynamic_tool1", options.Tools[1].Name)
					return nil, nil
				},
			},
		}

		stream, err := w.Stream(ctx, messages)
		assert.NoError(t, err)
		assert.Nil(t, stream)
	})
}

================================================
FILE: adk/middlewares/filesystem/backend.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. */ // Package filesystem provides middlewares. package filesystem import ( "github.com/cloudwego/eino/adk/filesystem" ) type FileInfo = filesystem.FileInfo type GrepMatch = filesystem.GrepMatch type LsInfoRequest = filesystem.LsInfoRequest type ReadRequest = filesystem.ReadRequest type GrepRequest = filesystem.GrepRequest type GlobInfoRequest = filesystem.GlobInfoRequest type WriteRequest = filesystem.WriteRequest type EditRequest = filesystem.EditRequest type FileContent = filesystem.FileContent ================================================ FILE: adk/middlewares/filesystem/filesystem.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package filesystem import ( "context" "errors" "fmt" "io" "path/filepath" "runtime/debug" "sort" "strconv" "strings" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/components/tool/utils" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) const ( ToolNameLs = "ls" ToolNameReadFile = "read_file" ToolNameWriteFile = "write_file" ToolNameEditFile = "edit_file" ToolNameGlob = "glob" ToolNameGrep = "grep" ToolNameExecute = "execute" noFilesFound = "No files found" noMatchesFound = "No matches found" ) // ToolConfig configures a filesystem tool type ToolConfig struct { // Name overrides the tool name used in tool registration // optional, default tool name will be used if not set (empty string) Name string // Desc overrides the tool description used in tool registration // optional, default tool description will be used if not set (nil pointer) Desc *string // CustomTool provides a custom implementation for this tool. // If set, this custom tool will be used instead of the default implementation associated with Backend. // If not set, the default tool implementation associated with Backend will be created automatically. // optional CustomTool tool.BaseTool // Disable disables this tool // If true, the tool will not be registered // optional, false by default Disable bool } // Config is the configuration for the filesystem middleware type Config struct { // Backend provides filesystem operations used by tools and offloading. // If set, filesystem tools (read_file, write_file, edit_file, glob, grep) will be registered. // At least one of Backend, Shell, or StreamingShell must be set. Backend filesystem.Backend // Shell provides shell command execution capability. // If set, an execute tool will be registered to support shell command execution. // At least one of Backend, Shell, or StreamingShell must be set. 
// Mutually exclusive with StreamingShell. Shell filesystem.Shell // StreamingShell provides streaming shell command execution capability. // If set, a streaming execute tool will be registered to support streaming shell command execution. // At least one of Backend, Shell, or StreamingShell must be set. // Mutually exclusive with Shell. StreamingShell filesystem.StreamingShell // LsToolConfig configures the ls tool // optional LsToolConfig *ToolConfig // ReadFileToolConfig configures the read_file tool // optional ReadFileToolConfig *ToolConfig // WriteFileToolConfig configures the write_file tool // optional WriteFileToolConfig *ToolConfig // EditFileToolConfig configures the edit_file tool // optional EditFileToolConfig *ToolConfig // GlobToolConfig configures the glob tool // optional GlobToolConfig *ToolConfig // GrepToolConfig configures the grep tool // optional GrepToolConfig *ToolConfig // WithoutLargeToolResultOffloading disables automatic offloading of large tool result to Backend // optional, false(enabled) by default WithoutLargeToolResultOffloading bool // LargeToolResultOffloadingTokenLimit sets the token threshold to trigger offloading // optional, 20000 by default LargeToolResultOffloadingTokenLimit int // LargeToolResultOffloadingPathGen generates the write path for offloaded results based on context and ToolInput // optional, "/large_tool_result/{ToolCallID}" by default LargeToolResultOffloadingPathGen func(ctx context.Context, input *compose.ToolInput) (string, error) // CustomSystemPrompt overrides the default ToolsSystemPrompt appended to agent instruction // optional, ToolsSystemPrompt by default CustomSystemPrompt *string // CustomLsToolDesc overrides the ls tool description used in tool registration // optional, ListFilesToolDesc by default // Deprecated: Use LsToolConfig.Desc instead CustomLsToolDesc *string // CustomReadFileToolDesc overrides the read_file tool description // optional, ReadFileToolDesc by default // Deprecated: Use 
ReadFileToolConfig.Desc instead CustomReadFileToolDesc *string // CustomGrepToolDesc overrides the grep tool description // optional, GrepToolDesc by default // Deprecated: Use GrepToolConfig.Desc instead CustomGrepToolDesc *string // CustomGlobToolDesc overrides the glob tool description // optional, GlobToolDesc by default // Deprecated: Use GlobToolConfig.Desc instead CustomGlobToolDesc *string // CustomWriteFileToolDesc overrides the write_file tool description // optional, WriteFileToolDesc by default // Deprecated: Use WriteFileToolConfig.Desc instead CustomWriteFileToolDesc *string // CustomEditToolDesc overrides the edit_file tool description // optional, EditFileToolDesc by default // Deprecated: Use EditFileToolConfig.Desc instead CustomEditToolDesc *string } func (c *Config) Validate() error { if c == nil { return errors.New("config should not be nil") } if c.Backend == nil { return errors.New("backend should not be nil") } if c.StreamingShell != nil && c.Shell != nil { return errors.New("shell and streaming shell should not be both set") } return nil } // NewMiddleware constructs and returns the filesystem middleware. // // Deprecated: Use New instead. New returns // a ChatModelAgentMiddleware which provides better context propagation through wrapper methods // and is the recommended approach for new code. See ChatModelAgentMiddleware documentation // for details on the benefits over AgentMiddleware. 
func NewMiddleware(ctx context.Context, config *Config) (adk.AgentMiddleware, error) { err := config.Validate() if err != nil { return adk.AgentMiddleware{}, err } ts, err := getFilesystemTools(ctx, &MiddlewareConfig{ Backend: config.Backend, Shell: config.Shell, StreamingShell: config.StreamingShell, LsToolConfig: config.LsToolConfig, ReadFileToolConfig: config.ReadFileToolConfig, WriteFileToolConfig: config.WriteFileToolConfig, EditFileToolConfig: config.EditFileToolConfig, GlobToolConfig: config.GlobToolConfig, GrepToolConfig: config.GrepToolConfig, CustomSystemPrompt: config.CustomSystemPrompt, CustomLsToolDesc: config.CustomLsToolDesc, CustomReadFileToolDesc: config.CustomReadFileToolDesc, CustomGrepToolDesc: config.CustomGrepToolDesc, CustomGlobToolDesc: config.CustomGlobToolDesc, CustomWriteFileToolDesc: config.CustomWriteFileToolDesc, CustomEditToolDesc: config.CustomEditToolDesc, }) if err != nil { return adk.AgentMiddleware{}, err } var systemPrompt string if config.CustomSystemPrompt != nil { systemPrompt = *config.CustomSystemPrompt } m := adk.AgentMiddleware{ AdditionalInstruction: systemPrompt, AdditionalTools: ts, } if !config.WithoutLargeToolResultOffloading { m.WrapToolCall = newToolResultOffloading(ctx, &toolResultOffloadingConfig{ Backend: config.Backend, TokenLimit: config.LargeToolResultOffloadingTokenLimit, PathGenerator: config.LargeToolResultOffloadingPathGen, }) } return m, nil } // MiddlewareConfig is the configuration for the filesystem middleware type MiddlewareConfig struct { // Backend provides filesystem operations used by tools and offloading. // required Backend filesystem.Backend // Shell provides shell command execution capability. // If set, an execute tool will be registered to support shell command execution. // optional, mutually exclusive with StreamingShell Shell filesystem.Shell // StreamingShell provides streaming shell command execution capability. 
// If set, a streaming execute tool will be registered for real-time output. // optional, mutually exclusive with Shell StreamingShell filesystem.StreamingShell // LsToolConfig configures the ls tool // optional LsToolConfig *ToolConfig // ReadFileToolConfig configures the read_file tool // optional ReadFileToolConfig *ToolConfig // WriteFileToolConfig configures the write_file tool // optional WriteFileToolConfig *ToolConfig // EditFileToolConfig configures the edit_file tool // optional EditFileToolConfig *ToolConfig // GlobToolConfig configures the glob tool // optional GlobToolConfig *ToolConfig // GrepToolConfig configures the grep tool // optional GrepToolConfig *ToolConfig // CustomSystemPrompt overrides the default ToolsSystemPrompt appended to agent instruction // optional, ToolsSystemPrompt by default CustomSystemPrompt *string // CustomLsToolDesc overrides the ls tool description used in tool registration // optional, ListFilesToolDesc by default // Deprecated: Use LsToolConfig.Desc instead CustomLsToolDesc *string // CustomReadFileToolDesc overrides the read_file tool description // optional, ReadFileToolDesc by default // Deprecated: Use ReadFileToolConfig.Desc instead CustomReadFileToolDesc *string // CustomGrepToolDesc overrides the grep tool description // optional, GrepToolDesc by default // Deprecated: Use GrepToolConfig.Desc instead CustomGrepToolDesc *string // CustomGlobToolDesc overrides the glob tool description // optional, GlobToolDesc by default // Deprecated: Use GlobToolConfig.Desc instead CustomGlobToolDesc *string // CustomWriteFileToolDesc overrides the write_file tool description // optional, WriteFileToolDesc by default // Deprecated: Use WriteFileToolConfig.Desc instead CustomWriteFileToolDesc *string // CustomEditToolDesc overrides the edit_file tool description // optional, EditFileToolDesc by default // Deprecated: Use EditFileToolConfig.Desc instead CustomEditToolDesc *string } func (c *MiddlewareConfig) Validate() error { if c 
== nil { return errors.New("config should not be nil") } if c.Backend == nil { return errors.New("backend should not be nil") } if c.StreamingShell != nil && c.Shell != nil { return errors.New("shell and streaming shell should not be both set") } return nil } // mergeToolConfigWithDesc merges ToolConfig with legacy Desc field // Priority: ToolConfig.Desc > legacy Desc // Returns an empty ToolConfig if both are nil (to allow backend default implementation) func (c *MiddlewareConfig) mergeToolConfigWithDesc( toolConfig *ToolConfig, legacyDesc *string, ) *ToolConfig { if toolConfig == nil && legacyDesc == nil { return &ToolConfig{} } if toolConfig == nil { return &ToolConfig{ Desc: legacyDesc, } } if toolConfig.Desc == nil && legacyDesc != nil { merged := *toolConfig merged.Desc = legacyDesc return &merged } return toolConfig } // New constructs and returns the filesystem middleware as a ChatModelAgentMiddleware. // // This is the recommended constructor for new code. It returns a ChatModelAgentMiddleware which provides: // - Better context propagation through WrapInvokableToolCall and WrapStreamableToolCall methods // - BeforeAgent hook for modifying agent instruction and tools at runtime // - More flexible extension points compared to the struct-based AgentMiddleware // // The middleware provides filesystem tools (ls, read_file, write_file, edit_file, glob, grep) // and optionally an execute tool if the Backend implements ShellBackend or StreamingShellBackend. // // Example usage: // // middleware, err := filesystem.New(ctx, &filesystem.Config{ // Backend: myBackend, // }) // agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ // // ... 
// Handlers: []adk.ChatModelAgentMiddleware{middleware}, // }) func New(ctx context.Context, config *MiddlewareConfig) (adk.ChatModelAgentMiddleware, error) { err := config.Validate() if err != nil { return nil, err } ts, err := getFilesystemTools(ctx, config) if err != nil { return nil, err } var systemPrompt string if config.CustomSystemPrompt != nil { systemPrompt = *config.CustomSystemPrompt } m := &filesystemMiddleware{ additionalInstruction: systemPrompt, additionalTools: ts, } return m, nil } type filesystemMiddleware struct { adk.BaseChatModelAgentMiddleware additionalInstruction string additionalTools []tool.BaseTool } func (m *filesystemMiddleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { if runCtx == nil { return ctx, runCtx, nil } nRunCtx := *runCtx if m.additionalInstruction != "" { nRunCtx.Instruction = nRunCtx.Instruction + "\n" + m.additionalInstruction } nRunCtx.Tools = append(nRunCtx.Tools, m.additionalTools...) return ctx, &nRunCtx, nil } // toolSpec defines a specification for creating a filesystem tool. // It unifies the tool creation process by encapsulating the tool configuration, // legacy descriptor, and the creation function. 
type toolSpec struct { config *ToolConfig legacyDesc *string createFunc func(name, desc string) (tool.BaseTool, error) } func getFilesystemTools(_ context.Context, middlewareConfig *MiddlewareConfig) ([]tool.BaseTool, error) { var tools []tool.BaseTool toolSpecs := []toolSpec{ { config: middlewareConfig.LsToolConfig, legacyDesc: middlewareConfig.CustomLsToolDesc, createFunc: func(name, desc string) (tool.BaseTool, error) { if middlewareConfig.Backend != nil { return newLsTool(middlewareConfig.Backend, name, desc) } return nil, nil }, }, { config: middlewareConfig.ReadFileToolConfig, legacyDesc: middlewareConfig.CustomReadFileToolDesc, createFunc: func(name, desc string) (tool.BaseTool, error) { if middlewareConfig.Backend != nil { return newReadFileTool(middlewareConfig.Backend, name, desc) } return nil, nil }, }, { config: middlewareConfig.WriteFileToolConfig, legacyDesc: middlewareConfig.CustomWriteFileToolDesc, createFunc: func(name, desc string) (tool.BaseTool, error) { if middlewareConfig.Backend != nil { return newWriteFileTool(middlewareConfig.Backend, name, desc) } return nil, nil }, }, { config: middlewareConfig.EditFileToolConfig, legacyDesc: middlewareConfig.CustomEditToolDesc, createFunc: func(name, desc string) (tool.BaseTool, error) { if middlewareConfig.Backend != nil { return newEditFileTool(middlewareConfig.Backend, name, desc) } return nil, nil }, }, { config: middlewareConfig.GlobToolConfig, legacyDesc: middlewareConfig.CustomGlobToolDesc, createFunc: func(name, desc string) (tool.BaseTool, error) { if middlewareConfig.Backend != nil { return newGlobTool(middlewareConfig.Backend, name, desc) } return nil, nil }, }, { config: middlewareConfig.GrepToolConfig, legacyDesc: middlewareConfig.CustomGrepToolDesc, createFunc: func(name, desc string) (tool.BaseTool, error) { if middlewareConfig.Backend != nil { return newGrepTool(middlewareConfig.Backend, name, desc) } return nil, nil }, }, } for _, spec := range toolSpecs { t, err := 
createToolFromSpec(middlewareConfig, spec) if err != nil { return nil, err } if t != nil { tools = append(tools, t) } } // Create execute tool if Shell or StreamingShell is available if middlewareConfig.StreamingShell != nil { executeDesc, err := selectToolDesc("", ExecuteToolDesc, ExecuteToolDescChinese) if err != nil { return nil, err } executeTool, err := newStreamingExecuteTool(middlewareConfig.StreamingShell, ToolNameExecute, executeDesc) if err != nil { return nil, err } tools = append(tools, executeTool) } else if middlewareConfig.Shell != nil { executeDesc, err := selectToolDesc("", ExecuteToolDesc, ExecuteToolDescChinese) if err != nil { return nil, err } executeTool, err := newExecuteTool(middlewareConfig.Shell, ToolNameExecute, executeDesc) if err != nil { return nil, err } tools = append(tools, executeTool) } return tools, nil } // createToolFromSpec creates a tool instance based on the provided toolSpec. // It handles configuration merging (ToolConfig + legacy Desc), checks if the tool // is disabled, and prioritizes CustomTool over the default implementation. 
func createToolFromSpec(middlewareConfig *MiddlewareConfig, spec toolSpec) (tool.BaseTool, error) { mergedConfig := middlewareConfig.mergeToolConfigWithDesc(spec.config, spec.legacyDesc) if mergedConfig.Disable { return nil, nil } return getOrCreateTool(mergedConfig.CustomTool, func() (tool.BaseTool, error) { desc := "" if mergedConfig.Desc != nil { desc = *mergedConfig.Desc } return spec.createFunc(mergedConfig.Name, desc) }) } func getOrCreateTool(customTool tool.BaseTool, createFunc func() (tool.BaseTool, error)) (tool.BaseTool, error) { if customTool != nil { return customTool, nil } return createFunc() } type lsArgs struct { Path string `json:"path"` } func newLsTool(fs filesystem.Backend, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameLs) d, err := selectToolDesc(desc, ListFilesToolDesc, ListFilesToolDescChinese) if err != nil { return nil, err } return utils.InferTool(toolName, d, func(ctx context.Context, input lsArgs) (string, error) { infos, err := fs.LsInfo(ctx, &filesystem.LsInfoRequest{Path: input.Path}) if err != nil { return "", err } if len(infos) == 0 { return noFilesFound, nil } paths := make([]string, 0, len(infos)) for _, fi := range infos { paths = append(paths, fi.Path) } return strings.Join(paths, "\n"), nil }) } type readFileArgs struct { // FilePath is the path to the file to read. FilePath string `json:"file_path" jsonschema:"description=The path to the file to read"` // Offset is the line number to start reading from. Offset int `json:"offset" jsonschema:"description=The line number to start reading from. Only provide if the file is too large to read at once"` // Limit is the number of lines to read. Limit int `json:"limit" jsonschema:"description=The number of lines to read. 
Only provide if the file is too large to read at once."` } func newReadFileTool(fs filesystem.Backend, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameReadFile) d, err := selectToolDesc(desc, ReadFileToolDesc, ReadFileToolDescChinese) if err != nil { return nil, err } return utils.InferTool(toolName, d, func(ctx context.Context, input readFileArgs) (string, error) { if input.Offset <= 0 { input.Offset = 1 } fileCt, err := fs.Read(ctx, &filesystem.ReadRequest{ FilePath: input.FilePath, Offset: input.Offset, Limit: input.Limit, }) if err != nil { return "", err } startLine := input.Offset lines := strings.Split(fileCt.Content, "\n") var b strings.Builder for i, line := range lines { if i < len(lines)-1 { fmt.Fprintf(&b, "%6d\t%s\n", startLine+i, line) } else { fmt.Fprintf(&b, "%6d\t%s", startLine+i, line) } } return b.String(), nil }) } type writeFileArgs struct { // FilePath is the path to the file to write. FilePath string `json:"file_path" jsonschema:"description=The path to the file to write"` // Content is the content to write to the file. Content string `json:"content" jsonschema:"description=The content to write to the file"` } func newWriteFileTool(fs filesystem.Backend, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameWriteFile) d, err := selectToolDesc(desc, WriteFileToolDesc, WriteFileToolDescChinese) if err != nil { return nil, err } return utils.InferTool(toolName, d, func(ctx context.Context, input writeFileArgs) (string, error) { err := fs.Write(ctx, &filesystem.WriteRequest{ FilePath: input.FilePath, Content: input.Content, }) if err != nil { return "", err } return fmt.Sprintf("Updated file %s", input.FilePath), nil }) } type editFileArgs struct { // FilePath is the path to the file to modify. FilePath string `json:"file_path" jsonschema:"description=The path to the file to modify"` // OldString is the text to replace. 
OldString string `json:"old_string" jsonschema:"description=The text to replace"` // NewString is the text to replace it with. NewString string `json:"new_string" jsonschema:"description=The text to replace it with (must be different from old_string)"` // ReplaceAll indicates whether to replace all occurrences of old_string. ReplaceAll bool `json:"replace_all" jsonschema:"description=Replace all occurrences of old_string (default false),default=false"` } func newEditFileTool(fs filesystem.Backend, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameEditFile) d, err := selectToolDesc(desc, EditFileToolDesc, EditFileToolDescChinese) if err != nil { return nil, err } return utils.InferTool(toolName, d, func(ctx context.Context, input editFileArgs) (string, error) { err := fs.Edit(ctx, &filesystem.EditRequest{ FilePath: input.FilePath, OldString: input.OldString, NewString: input.NewString, ReplaceAll: input.ReplaceAll, }) if err != nil { return "", err } return fmt.Sprintf("Successfully replaced the string in '%s'", input.FilePath), nil }) } type globArgs struct { // Pattern is the glob pattern to match files against. Pattern string `json:"pattern" jsonschema:"description=The glob pattern to match files against"` // Path is the directory to search in. Path string `json:"path" jsonschema:"description=The directory to search in. If not specified\\, the current working directory will be used. IMPORTANT: Omit this field to use the default directory. DO NOT enter 'undefined' or 'null' - simply omit it for the default behavior. 
Must be a valid directory path if provided."` } func newGlobTool(fs filesystem.Backend, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameGlob) d, err := selectToolDesc(desc, GlobToolDesc, GlobToolDescChinese) if err != nil { return nil, err } return utils.InferTool(toolName, d, func(ctx context.Context, input globArgs) (string, error) { infos, err := fs.GlobInfo(ctx, &filesystem.GlobInfoRequest{ Pattern: input.Pattern, Path: input.Path, }) if err != nil { return "", err } if len(infos) == 0 { return noFilesFound, nil } paths := make([]string, 0, len(infos)) for _, fi := range infos { paths = append(paths, fi.Path) } return strings.Join(paths, "\n"), nil }) } type grepArgs struct { // Pattern is the regular expression pattern to search for in file contents. Pattern string `json:"pattern" jsonschema:"description=The regular expression pattern to search for in file contents"` // Path is the file or directory to search in. Defaults to current working directory. Path *string `json:"path,omitempty" jsonschema:"description=File or directory to search in (rg PATH). Defaults to current working directory."` // Glob is the glob pattern to filter files (e.g. "*.js", "*.{ts,tsx}"). Glob *string `json:"glob,omitempty" jsonschema:"description=Glob pattern to filter files (e.g. '*.js'\\, '*.{ts\\,tsx}') - maps to rg --glob"` // OutputMode specifies the output format. // "content" shows matching lines (supports context, line numbers, head_limit). // "files_with_matches" shows file paths (supports head_limit). // "count" shows match counts (supports head_limit). // Defaults to "files_with_matches". OutputMode string `json:"output_mode,omitempty" jsonschema:"description=Output mode: 'content' shows matching lines (supports -A/-B/-C context\\, -n line numbers\\, head_limit)\\, 'files_with_matches' shows file paths (supports head_limit)\\, 'count' shows match counts (supports head_limit). 
Defaults to 'files_with_matches'.,enum=content,enum=files_with_matches,enum=count"` // Context is the number of lines to show before and after each match. // Only applicable when output_mode is "content". Context *int `json:"-C,omitempty" jsonschema:"description=Number of lines to show before and after each match (rg -C). Requires output_mode: 'content'\\, ignored otherwise."` // BeforeLines is the number of lines to show before each match. // Only applicable when output_mode is "content". BeforeLines *int `json:"-B,omitempty" jsonschema:"description=Number of lines to show before each match (rg -B). Requires output_mode: 'content'\\, ignored otherwise."` // AfterLines is the number of lines to show after each match. // Only applicable when output_mode is "content". AfterLines *int `json:"-A,omitempty" jsonschema:"description=Number of lines to show after each match (rg -A). Requires output_mode: 'content'\\, ignored otherwise."` // ShowLineNumbers enables showing line numbers in output. // Only applicable when output_mode is "content". Defaults to true. ShowLineNumbers *bool `json:"-n,omitempty" jsonschema:"description=Show line numbers in output (rg -n). Requires output_mode: 'content'\\, ignored otherwise. Defaults to true."` // CaseInsensitive enables case insensitive search. CaseInsensitive *bool `json:"-i,omitempty" jsonschema:"description=Case insensitive search (rg -i)"` // FileType is the file type to search (e.g., js, py, rust, go, java). // More efficient than Glob for standard file types. FileType *string `json:"type,omitempty" jsonschema:"description=File type to search (rg --type). Common types: js\\, py\\, rust\\, go\\, java\\, etc. More efficient than include for standard file types."` // HeadLimit limits output to first N lines/entries. // Works across all output modes. Defaults to 0 (unlimited). HeadLimit *int `json:"head_limit,omitempty" jsonschema:"description=Limit output to first N lines/entries\\, equivalent to '| head -N'. 
Works across all output modes: content (limits output lines)\\, files_with_matches (limits file paths)\\, count (limits count entries). Defaults to 0 (unlimited)."` // Offset skips first N lines/entries before applying HeadLimit. // Works across all output modes. Defaults to 0. Offset *int `json:"offset,omitempty" jsonschema:"description=Skip first N lines/entries before applying head_limit\\, equivalent to '| tail -n +N | head -N'. Works across all output modes. Defaults to 0."` // Multiline enables multiline mode where patterns can span lines. // - true: Allows patterns to match across lines, "." matches newlines // - false: Default, matches only within single lines Multiline *bool `json:"multiline,omitempty" jsonschema:"description=Enable multiline mode where . matches newlines and patterns can span lines (rg -U --multiline-dotall). Default: false."` } func newGrepTool(fs filesystem.Backend, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameGrep) d, err := selectToolDesc(desc, GrepToolDesc, GrepToolDescChinese) if err != nil { return nil, err } return utils.InferTool(toolName, d, func(ctx context.Context, input grepArgs) (string, error) { // Extract string parameters path := valueOrDefault(input.Path, "") glob := valueOrDefault(input.Glob, "") fileType := valueOrDefault(input.FileType, "") var beforeLines, afterLines int if input.Context != nil { beforeLines = valueOrDefault(input.Context, 0) afterLines = valueOrDefault(input.Context, 0) } else { // Extract context parameters beforeLines = valueOrDefault(input.BeforeLines, 0) afterLines = valueOrDefault(input.AfterLines, 0) } // Extract boolean flags caseInsensitive := valueOrDefault(input.CaseInsensitive, false) enableMultiline := valueOrDefault(input.Multiline, false) // Extract pagination parameters headLimit := valueOrDefault(input.HeadLimit, 0) offset := valueOrDefault(input.Offset, 0) matches, err := fs.GrepRaw(ctx, &filesystem.GrepRequest{ Pattern: input.Pattern, 
Path: path, Glob: glob, FileType: fileType, CaseInsensitive: caseInsensitive, AfterLines: afterLines, BeforeLines: beforeLines, EnableMultiline: enableMultiline, }) if err != nil { return "", err } sort.SliceStable(matches, func(i, j int) bool { return filepath.Base(matches[i].Path) < filepath.Base(matches[j].Path) }) switch input.OutputMode { case "content": matches = applyPagination(matches, offset, headLimit) return formatContentMatches(matches, valueOrDefault(input.ShowLineNumbers, true)), nil case "count": return formatCountMatches(matches, offset, headLimit), nil case "files_with_matches": return formatFileMatches(matches, offset, headLimit), nil default: return formatFileMatches(matches, offset, headLimit), nil } }) } type executeArgs struct { Command string `json:"command"` } func newExecuteTool(sb filesystem.Shell, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameExecute) d, err := selectToolDesc(desc, ExecuteToolDesc, ExecuteToolDescChinese) if err != nil { return nil, err } return utils.InferTool(toolName, d, func(ctx context.Context, input executeArgs) (string, error) { result, err := sb.Execute(ctx, &filesystem.ExecuteRequest{ Command: input.Command, }) if err != nil { return "", err } return convExecuteResponse(result), nil }) } func newStreamingExecuteTool(sb filesystem.StreamingShell, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameExecute) d, err := selectToolDesc(desc, ExecuteToolDesc, ExecuteToolDescChinese) if err != nil { return nil, err } return utils.InferStreamTool(toolName, d, func(ctx context.Context, input executeArgs) (*schema.StreamReader[string], error) { result, err := sb.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{ Command: input.Command, }) if err != nil { return nil, err } sr, sw := schema.Pipe[string](10) go func() { defer func() { e := recover() if e != nil { sw.Send("", fmt.Errorf("panic: %v,\n stack: %s", e, string(debug.Stack()))) } 
sw.Close() }() var hasSentContent bool var exitCode *int for { chunk, recvErr := result.Recv() if recvErr == io.EOF { break } if recvErr != nil { sw.Send("", recvErr) return } if chunk == nil { continue } if chunk.ExitCode != nil { exitCode = chunk.ExitCode } parts := make([]string, 0, 2) if chunk.Output != "" { parts = append(parts, chunk.Output) } if chunk.Truncated { parts = append(parts, "[Output was truncated due to size limits]") } if len(parts) > 0 { sw.Send(strings.Join(parts, "\n"), nil) hasSentContent = true } } if exitCode != nil && *exitCode != 0 { sw.Send(fmt.Sprintf("\n[Command failed with exit code %d]", *exitCode), nil) } else if !hasSentContent { sw.Send("[Command executed successfully with no output]", nil) } }() return sr, nil }) } func convExecuteResponse(response *filesystem.ExecuteResponse) string { if response == nil { return "" } parts := []string{response.Output} if response.ExitCode != nil && *response.ExitCode != 0 { parts = append(parts, fmt.Sprintf("[Command failed with exit code %d]", *response.ExitCode)) } if response.Truncated { parts = append(parts, "[Output was truncated due to size limits]") } result := strings.Join(parts, "\n") if result == "" && (response.ExitCode == nil || *response.ExitCode == 0) { return "[Command executed successfully with no output]" } return result } // valueOrDefault returns the value pointed to by ptr, or defaultValue if ptr is nil. 
func valueOrDefault[T any](ptr *T, defaultValue T) T { if ptr != nil { return *ptr } return defaultValue } func applyPagination[T any](items []T, offset, headLimit int) []T { if offset < 0 { offset = 0 } if offset >= len(items) { return []T{} } items = items[offset:] if headLimit > 0 && headLimit < len(items) { items = items[:headLimit] } return items } func formatFileMatches(matches []filesystem.GrepMatch, offset, headLimit int) string { if len(matches) == 0 { return noFilesFound } seen := make(map[string]bool) var uniquePaths []string for _, match := range matches { if !seen[match.Path] { seen[match.Path] = true uniquePaths = append(uniquePaths, match.Path) } } totalFiles := len(uniquePaths) uniquePaths = applyPagination(uniquePaths, offset, headLimit) fileWord := "files" if totalFiles == 1 { fileWord = "file" } return fmt.Sprintf("Found %d %s\n%s", totalFiles, fileWord, strings.Join(uniquePaths, "\n")) } func formatContentMatches(matches []filesystem.GrepMatch, showLineNum bool) string { if len(matches) == 0 { return noMatchesFound } var b strings.Builder for _, match := range matches { b.WriteString(match.Path) if showLineNum { b.WriteString(":") b.WriteString(strconv.Itoa(match.Line)) } b.WriteString(":") b.WriteString(match.Content) b.WriteString("\n") } return strings.TrimSuffix(b.String(), "\n") } func formatCountMatches(matches []filesystem.GrepMatch, offset, headLimit int) string { countMap := make(map[string]int) for _, match := range matches { countMap[match.Path]++ } var paths []string for path := range countMap { paths = append(paths, path) } sort.Strings(paths) totalOccurrences := len(matches) totalFiles := len(paths) occurrenceWord := "occurrences" if totalOccurrences == 1 { occurrenceWord = "occurrence" } fileWord := "files" if totalFiles == 1 { fileWord = "file" } if totalOccurrences == 0 { return fmt.Sprintf("%s\n\nFound %d total %s across %d %s.", noMatchesFound, totalOccurrences, occurrenceWord, totalFiles, fileWord) } paths = 
applyPagination(paths, offset, headLimit) var b strings.Builder for _, path := range paths { b.WriteString(path) b.WriteString(":") b.WriteString(strconv.Itoa(countMap[path])) b.WriteString("\n") } result := strings.TrimSuffix(b.String(), "\n") return fmt.Sprintf("%s\n\nFound %d total %s across %d %s.", result, totalOccurrences, occurrenceWord, totalFiles, fileWord) } // selectToolDesc returns the custom description if provided, otherwise selects the appropriate // i18n description based on the current language setting. func selectToolDesc(customDesc string, defaultEnglish, defaultChinese string) (string, error) { if customDesc != "" { return customDesc, nil } return internal.SelectPrompt(internal.I18nPrompts{ English: defaultEnglish, Chinese: defaultChinese, }), nil } // selectToolName returns the custom tool name if provided, otherwise returns the default name. func selectToolName(customName string, defaultName string) string { if customName != "" { return customName } return defaultName } ================================================ FILE: adk/middlewares/filesystem/filesystem_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package filesystem import ( "context" "errors" "fmt" "io" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) // setupTestBackend creates a test backend with some initial files func setupTestBackend() *filesystem.InMemoryBackend { backend := filesystem.NewInMemoryBackend() ctx := context.Background() // Create test files backend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/file1.txt", Content: "line1\nline2\nline3\nline4\nline5", }) backend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/file2.go", Content: "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}", }) backend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/dir1/file3.txt", Content: "hello world\nfoo bar\nhello again", }) backend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/dir1/file4.py", Content: "print('hello')\nprint('world')", }) backend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/dir2/file5.go", Content: "package test\n\nfunc test() {}", }) return backend } // invokeTool is a helper to invoke a tool with JSON input func invokeTool(_ *testing.T, bt tool.BaseTool, input string) (string, error) { ctx := context.Background() result, err := bt.(tool.InvokableTool).InvokableRun(ctx, input) if err != nil { return "", err } return result, nil } func TestLsTool(t *testing.T) { backend := setupTestBackend() lsTool, err := newLsTool(backend, "", "") if err != nil { t.Fatalf("Failed to create ls tool: %v", err) } tests := []struct { name string input string expected []string // expected paths in output }{ { name: "list root", input: `{"path": "/"}`, expected: []string{"file1.txt", "file2.go", "dir1", "dir2"}, }, { name: "list empty path (defaults to root)", input: `{"path": ""}`, expected: []string{"file1.txt", "file2.go", "dir1", "dir2"}, }, { name: "list dir1", input: `{"path": "/dir1"}`, expected: []string{"file3.txt", 
"file4.py"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := invokeTool(t, lsTool, tt.input) if err != nil { t.Fatalf("ls tool failed: %v", err) } for _, expectedPath := range tt.expected { if !strings.Contains(result, expectedPath) { t.Errorf("Expected output to contain %q, got: %s", expectedPath, result) } } }) } } func TestReadFileTool(t *testing.T) { backend := setupTestBackend() readTool, err := newReadFileTool(backend, "", "") if err != nil { t.Fatalf("Failed to create read_file tool: %v", err) } tests := []struct { name string input string expected string shouldError bool }{ { name: "read full file", input: `{"file_path": "/file1.txt", "offset": 0, "limit": 100}`, expected: " 1\tline1\n 2\tline2\n 3\tline3\n 4\tline4\n 5\tline5", }, { name: "read with offset", input: `{"file_path": "/file1.txt", "offset": 2, "limit": 2}`, expected: " 2\tline2\n 3\tline3", }, { name: "read with default limit", input: `{"file_path": "/file1.txt", "offset": 0, "limit": 0}`, expected: " 1\tline1\n 2\tline2\n 3\tline3\n 4\tline4\n 5\tline5", }, { name: "read with negative offset (treated as 0)", input: `{"file_path": "/file1.txt", "offset": -1, "limit": 2}`, expected: " 1\tline1\n 2\tline2", }, { name: "read non-existent file", input: `{"file_path": "/nonexistent.txt", "offset": 0, "limit": 10}`, shouldError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := invokeTool(t, readTool, tt.input) if tt.shouldError { if err == nil { t.Error("Expected error but got none") } return } if err != nil { t.Fatalf("read_file tool failed: %v", err) } if result != tt.expected { t.Errorf("Expected %q, got %q", tt.expected, result) } }) } } func TestWriteFileTool(t *testing.T) { backend := setupTestBackend() writeTool, err := newWriteFileTool(backend, "", "") if err != nil { t.Fatalf("Failed to create write_file tool: %v", err) } tests := []struct { name string input string expected string isError bool }{ { name: "write new 
file", input: `{"file_path": "/newfile.txt", "content": "new content"}`, expected: "Updated file /newfile.txt", }, { name: "overwrite existing file", input: `{"file_path": "/file1.txt", "content": "overwritten"}`, isError: false, expected: "Updated file /file1.txt", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := invokeTool(t, writeTool, tt.input) if tt.isError { if err == nil { t.Errorf("Expected an error, but got none") } return } if err != nil { t.Fatalf("write_file tool failed: %v", err) } if result != tt.expected { t.Errorf("Expected %q, got %q", tt.expected, result) } }) } // Verify the file was actually written ctx := context.Background() content, err := backend.Read(ctx, &filesystem.ReadRequest{ FilePath: "/newfile.txt", Offset: 0, Limit: 100, }) if err != nil { t.Fatalf("Failed to read written file: %v", err) } if content.Content != "new content" { t.Errorf("Expected written content to be 'new content', got %q", content) } } func TestEditFileTool(t *testing.T) { backend := setupTestBackend() editTool, err := newEditFileTool(backend, "", "") if err != nil { t.Fatalf("Failed to create edit_file tool: %v", err) } tests := []struct { name string setupFile string setupContent string input string expected string shouldError bool }{ { name: "replace first occurrence", setupFile: "/edit1.txt", setupContent: "hello world\nhello again\nhello world", input: `{"file_path": "/edit1.txt", "old_string": "hello again", "new_string": "hi", "replace_all": false}`, expected: "hello world\nhi\nhello world", }, { name: "replace all occurrences", setupFile: "/edit2.txt", setupContent: "hello world\nhello again\nhello world", input: `{"file_path": "/edit2.txt", "old_string": "hello", "new_string": "hi", "replace_all": true}`, expected: "hi world\nhi again\nhi world", }, { name: "non-existent file", setupFile: "", setupContent: "", input: `{"file_path": "/nonexistent.txt", "old_string": "old", "new_string": "new", "replace_all": false}`, 
shouldError: true, }, { name: "empty old_string", setupFile: "/edit3.txt", setupContent: "content", input: `{"file_path": "/edit3.txt", "old_string": "", "new_string": "new", "replace_all": false}`, shouldError: true, }, } ctx := context.Background() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Setup file if needed if tt.setupFile != "" { backend.Write(ctx, &filesystem.WriteRequest{ FilePath: tt.setupFile, Content: tt.setupContent, }) } _, err := invokeTool(t, editTool, tt.input) if tt.shouldError { if err == nil { t.Error("Expected error but got none") } return } if err != nil { t.Fatalf("edit_file tool failed: %v", err) } result, err := backend.Read(ctx, &filesystem.ReadRequest{ FilePath: tt.setupFile, Offset: 0, Limit: 0, }) if err != nil { t.Fatalf("edit_file tool failed: %v", err) } if result.Content != tt.expected { t.Errorf("Expected %q, got %q", tt.expected, result.Content) } }) } } func TestGlobTool(t *testing.T) { backend := setupTestBackend() globTool, err := newGlobTool(backend, "", "") if err != nil { t.Fatalf("Failed to create glob tool: %v", err) } tests := []struct { name string input string expected []string }{ { name: "match all .txt files in root", input: `{"pattern": "*.txt", "path": "/"}`, expected: []string{"file1.txt"}, }, { name: "match all .go files in root", input: `{"pattern": "*.go", "path": "/"}`, expected: []string{"file2.go"}, }, { name: "match all .txt files in dir1", input: `{"pattern": "*.txt", "path": "/dir1"}`, expected: []string{"file3.txt"}, }, { name: "match all .py files in dir1", input: `{"pattern": "*.py", "path": "/dir1"}`, expected: []string{"file4.py"}, }, { name: "empty path defaults to root", input: `{"pattern": "*.go", "path": ""}`, expected: []string{"file2.go"}, }, { name: "match all .txt files in dir1 in root dir", input: `{"pattern": "/dir1/*.txt", "path": "/"}`, expected: []string{"/dir1/file3.txt"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := 
invokeTool(t, globTool, tt.input) if err != nil { t.Fatalf("glob tool failed: %v", err) } for _, expectedPath := range tt.expected { if !strings.Contains(result, expectedPath) { t.Errorf("Expected output to contain %q, got: %s", expectedPath, result) } } }) } } func TestGrepTool(t *testing.T) { backend := setupTestBackend() grepTool, err := newGrepTool(backend, "", "") if err != nil { t.Fatalf("Failed to create grep tool: %v", err) } tests := []struct { name string input string expected string contains []string }{ { name: "grep with count mode", input: `{"pattern": "hello", "output_mode": "count"}`, expected: "/dir1/file3.txt:2\n/dir1/file4.py:1\n/file2.go:1\n\nFound 4 total occurrences across 3 files.", // 2 in file3.txt, 1 in file4.py, 1 in file2.go }, { name: "grep with content mode", input: `{"pattern": "hello", "output_mode": "content"}`, contains: []string{"/dir1/file3.txt:1:hello world", "/dir1/file3.txt:3:hello again", "/dir1/file4.py:1:print('hello')"}, }, { name: "grep with files_with_matches mode (default)", input: `{"pattern": "hello", "output_mode": "files_with_matches"}`, contains: []string{"/dir1/file3.txt", "/dir1/file4.py"}, }, { name: "grep with glob filter", input: `{"pattern": "hello", "glob": "*.txt", "output_mode": "count"}`, expected: "/dir1/file3.txt:2\n\nFound 2 total occurrences across 1 file.", // only in file3.txt }, { name: "grep withpath filter", input: `{"pattern": "package", "path": "/dir2", "output_mode": "count"}`, expected: "/dir2/file5.go:1\n\nFound 1 total occurrence across 1 file.", // only in dir2/file5.go }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := invokeTool(t, grepTool, tt.input) if err != nil { t.Fatalf("grep tool failed: %v", err) } if tt.expected != "" { if result != tt.expected { t.Errorf("Expected %q, got %q", tt.expected, result) } } for _, expectedStr := range tt.contains { if !strings.Contains(result, expectedStr) { t.Errorf("Expected output to contain %q, got: %s", 
expectedStr, result) } } }) } } func TestExecuteTool(t *testing.T) { backend := setupTestBackend() tests := []struct { name string resp *filesystem.ExecuteResponse input string expected string shouldError bool }{ { name: "successful command execution", resp: &filesystem.ExecuteResponse{ Output: "hello world", ExitCode: ptrOf(0), }, input: `{"command": "echo hello world"}`, expected: "hello world", }, { name: "command with non-zero exit code", resp: &filesystem.ExecuteResponse{ Output: "error: file not found", ExitCode: ptrOf(1), }, input: `{"command": "cat nonexistent.txt"}`, expected: "error: file not found\n[Command failed with exit code 1]", }, { name: "command with truncated output", resp: &filesystem.ExecuteResponse{ Output: "partial output...", ExitCode: ptrOf(0), Truncated: true, }, input: `{"command": "cat largefile.txt"}`, expected: "partial output...\n[Output was truncated due to size limits]", }, { name: "command with both non-zero exit code and truncated output", resp: &filesystem.ExecuteResponse{ Output: "error output...", ExitCode: ptrOf(2), Truncated: true, }, input: `{"command": "failing command"}`, expected: "error output...\n[Command failed with exit code 2]\n[Output was truncated due to size limits]", }, { name: "successful command with no output", resp: &filesystem.ExecuteResponse{ Output: "", ExitCode: ptrOf(0), }, input: `{"command": "mkdir /tmp/test"}`, expected: "[Command executed successfully with no output]", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { executeTool, err := newExecuteTool(&mockShellBackend{ Backend: backend, resp: tt.resp, }, "", "") assert.NoError(t, err) result, err := invokeTool(t, executeTool, tt.input) if tt.shouldError { assert.Error(t, err) return } assert.NoError(t, err) assert.Equal(t, tt.expected, result) }) } } func ptrOf[T any](t T) *T { return &t } type mockShellBackend struct { filesystem.Backend resp *filesystem.ExecuteResponse } func (m *mockShellBackend) Execute(ctx context.Context, 
req *filesystem.ExecuteRequest) (*filesystem.ExecuteResponse, error) { return m.resp, nil } func TestGetFilesystemTools(t *testing.T) { ctx := context.Background() backend := setupTestBackend() t.Run("returns 6 tools for regular Backend", func(t *testing.T) { tools, err := getFilesystemTools(ctx, &MiddlewareConfig{Backend: backend}) assert.NoError(t, err) assert.Len(t, tools, 6) // Verify tool names toolNames := make([]string, 0, len(tools)) for _, to := range tools { info, _ := to.Info(ctx) toolNames = append(toolNames, info.Name) } assert.Contains(t, toolNames, "ls") assert.Contains(t, toolNames, "read_file") assert.Contains(t, toolNames, "write_file") assert.Contains(t, toolNames, "edit_file") assert.Contains(t, toolNames, "glob") assert.Contains(t, toolNames, "grep") }) t.Run("returns 7 tools for Shell", func(t *testing.T) { shellBackend := &mockShellBackend{ Backend: backend, resp: &filesystem.ExecuteResponse{Output: "ok"}, } tools, err := getFilesystemTools(ctx, &MiddlewareConfig{Backend: shellBackend, Shell: shellBackend}) assert.NoError(t, err) assert.Len(t, tools, 7) // Verify execute tool is included toolNames := make([]string, 0, len(tools)) for _, to := range tools { info, _ := to.Info(ctx) toolNames = append(toolNames, info.Name) } assert.Contains(t, toolNames, "execute") }) t.Run("custom tool descriptions", func(t *testing.T) { customLsDesc := "Custom ls description" customReadDesc := "Custom read description" tools, err := getFilesystemTools(ctx, &MiddlewareConfig{ Backend: backend, CustomLsToolDesc: &customLsDesc, CustomReadFileToolDesc: &customReadDesc, }) assert.NoError(t, err) assert.Len(t, tools, 6) // Verify custom descriptions are applied for _, to := range tools { info, _ := to.Info(ctx) if info.Name == "ls" { assert.Equal(t, customLsDesc, info.Desc) } if info.Name == "read_file" { assert.Equal(t, customReadDesc, info.Desc) } } }) } func TestNew(t *testing.T) { ctx := context.Background() backend := setupTestBackend() t.Run("nil config 
returns error", func(t *testing.T) { _, err := New(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "config should not be nil") }) t.Run("nil backend returns error", func(t *testing.T) { _, err := New(ctx, &MiddlewareConfig{Backend: nil}) assert.Error(t, err) assert.Contains(t, err.Error(), "backend should not be nil") }) t.Run("valid config with default settings", func(t *testing.T) { m, err := New(ctx, &MiddlewareConfig{Backend: backend}) assert.NoError(t, err) assert.NotNil(t, m) fm, ok := m.(*filesystemMiddleware) assert.True(t, ok) assert.Len(t, fm.additionalTools, 6) }) t.Run("custom system prompt", func(t *testing.T) { customPrompt := "Custom system prompt" m, err := New(ctx, &MiddlewareConfig{ Backend: backend, CustomSystemPrompt: &customPrompt, }) assert.NoError(t, err) fm, ok := m.(*filesystemMiddleware) assert.True(t, ok) assert.Equal(t, customPrompt, fm.additionalInstruction) }) t.Run("ShellBackend adds execute tool", func(t *testing.T) { shellBackend := &mockShellBackend{ Backend: backend, resp: &filesystem.ExecuteResponse{Output: "ok"}, } m, err := New(ctx, &MiddlewareConfig{Backend: shellBackend, Shell: shellBackend}) assert.NoError(t, err) fm, ok := m.(*filesystemMiddleware) assert.True(t, ok) assert.Len(t, fm.additionalTools, 7) }) } func TestFilesystemMiddleware_BeforeAgent(t *testing.T) { ctx := context.Background() backend := setupTestBackend() t.Run("adds instruction and tools to context", func(t *testing.T) { m, err := New(ctx, &MiddlewareConfig{Backend: backend}) assert.NoError(t, err) runCtx := &adk.ChatModelAgentContext{ Instruction: "Original instruction", Tools: nil, } newCtx, newRunCtx, err := m.BeforeAgent(ctx, runCtx) assert.NoError(t, err) assert.NotNil(t, newCtx) assert.NotNil(t, newRunCtx) assert.Contains(t, newRunCtx.Instruction, "Original instruction") assert.Len(t, newRunCtx.Tools, 6) }) t.Run("nil runCtx returns nil", func(t *testing.T) { m, err := New(ctx, &MiddlewareConfig{Backend: backend}) assert.NoError(t, 
err) newCtx, newRunCtx, err := m.BeforeAgent(ctx, nil) assert.NoError(t, err) assert.NotNil(t, newCtx) assert.Nil(t, newRunCtx) }) } func TestFilesystemMiddleware_WrapInvokableToolCall(t *testing.T) { ctx := context.Background() backend := setupTestBackend() t.Run("small result passes through unchanged", func(t *testing.T) { m, err := New(ctx, &MiddlewareConfig{Backend: backend}) assert.NoError(t, err) endpoint := func(ctx context.Context, args string, opts ...tool.Option) (string, error) { return "small result", nil } tCtx := &adk.ToolContext{Name: "test_tool", CallID: "call-1"} wrapped, err := m.WrapInvokableToolCall(ctx, endpoint, tCtx) assert.NoError(t, err) result, err := wrapped(ctx, "{}") assert.NoError(t, err) assert.Equal(t, "small result", result) }) } func TestGrepToolWithSortingAndPagination(t *testing.T) { backend := filesystem.NewInMemoryBackend() ctx := context.Background() backend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/zebra.txt", Content: "match1\nmatch2\nmatch3", }) backend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/apple.txt", Content: "match4\nmatch5", }) backend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/banana.txt", Content: "match6\nmatch7\nmatch8", }) grepTool, err := newGrepTool(backend, "", "") assert.NoError(t, err) t.Run("files sorted by basename", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "files_with_matches"}`) assert.NoError(t, err) lines := strings.Split(strings.TrimSpace(result), "\n") assert.Equal(t, 4, len(lines)) // 1 summary + 3 files assert.Contains(t, lines[0], "Found 3 files") assert.Contains(t, lines[1], "apple.txt") assert.Contains(t, lines[2], "banana.txt") assert.Contains(t, lines[3], "zebra.txt") }) t.Run("files_with_matches with offset", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "files_with_matches", "offset": 1}`) assert.NoError(t, err) lines := strings.Split(strings.TrimSpace(result), 
"\n") assert.Equal(t, 3, len(lines)) // 1 summary + 2 files (pagination applied) assert.Contains(t, lines[0], "Found 3 files") // total count before pagination assert.Contains(t, lines[1], "banana.txt") assert.Contains(t, lines[2], "zebra.txt") }) t.Run("files_with_matches with head_limit", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "files_with_matches", "head_limit": 2}`) assert.NoError(t, err) lines := strings.Split(strings.TrimSpace(result), "\n") assert.Equal(t, 3, len(lines)) // 1 summary + 2 files (pagination applied) assert.Contains(t, lines[0], "Found 3 files") // total count before pagination assert.Contains(t, lines[1], "apple.txt") assert.Contains(t, lines[2], "banana.txt") }) t.Run("files_with_matches with offset and head_limit", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "files_with_matches", "offset": 1, "head_limit": 1}`) assert.NoError(t, err) lines := strings.Split(strings.TrimSpace(result), "\n") assert.Equal(t, 2, len(lines)) // 1 summary + 1 file (pagination applied) assert.Contains(t, lines[0], "Found 3 files") // total count before pagination assert.Contains(t, lines[1], "banana.txt") }) t.Run("content mode sorted and paginated", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "content", "head_limit": 3}`) assert.NoError(t, err) lines := strings.Split(strings.TrimSpace(result), "\n") assert.Equal(t, 3, len(lines)) assert.Contains(t, lines[0], "apple.txt") }) t.Run("content mode with offset", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "content", "offset": 2, "head_limit": 2}`) assert.NoError(t, err) lines := strings.Split(strings.TrimSpace(result), "\n") assert.Equal(t, 2, len(lines)) }) t.Run("count mode sorted", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "count"}`) assert.NoError(t, 
err) lines := strings.Split(strings.TrimSpace(result), "\n") assert.Equal(t, 5, len(lines)) // 3 file counts + 1 empty line + 1 summary line assert.Contains(t, lines[0], "apple.txt:2") assert.Contains(t, lines[1], "banana.txt:3") assert.Contains(t, lines[2], "zebra.txt:3") assert.Contains(t, lines[4], "Found 8 total occurrences across 3 files.") }) t.Run("count mode with pagination", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "count", "offset": 1, "head_limit": 1}`) assert.NoError(t, err) lines := strings.Split(strings.TrimSpace(result), "\n") assert.Equal(t, 3, len(lines)) // 1 file count + 1 empty line + 1 summary line assert.Contains(t, lines[0], "banana.txt:3") assert.Contains(t, lines[2], "Found 8 total occurrences across 3 files.") // summary shows total before pagination }) t.Run("offset exceeds result count", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "files_with_matches", "offset": 100}`) assert.NoError(t, err) assert.Contains(t, result, "Found 3 files") // still shows total count }) t.Run("negative offset treated as zero", func(t *testing.T) { result, err := invokeTool(t, grepTool, `{"pattern": "match", "output_mode": "files_with_matches", "offset": -5}`) assert.NoError(t, err) lines := strings.Split(strings.TrimSpace(result), "\n") assert.Equal(t, 4, len(lines)) // 1 summary + 3 files }) } func TestApplyPagination(t *testing.T) { t.Run("basic pagination", func(t *testing.T) { items := []string{"a", "b", "c", "d", "e"} result := applyPagination(items, 0, 3) assert.Equal(t, []string{"a", "b", "c"}, result) }) t.Run("with offset", func(t *testing.T) { items := []string{"a", "b", "c", "d", "e"} result := applyPagination(items, 2, 2) assert.Equal(t, []string{"c", "d"}, result) }) t.Run("offset exceeds length", func(t *testing.T) { items := []string{"a", "b", "c"} result := applyPagination(items, 10, 5) assert.Equal(t, []string{}, result) }) 
t.Run("negative offset", func(t *testing.T) { items := []string{"a", "b", "c"} result := applyPagination(items, -1, 2) assert.Equal(t, []string{"a", "b"}, result) }) t.Run("zero head limit means no limit", func(t *testing.T) { items := []string{"a", "b", "c", "d", "e"} result := applyPagination(items, 1, 0) assert.Equal(t, []string{"b", "c", "d", "e"}, result) }) } func TestCustomToolNames(t *testing.T) { backend := setupTestBackend() ctx := context.Background() t.Run("custom tool names applied to individual tools", func(t *testing.T) { customLsName := "list_files" customReadName := "read" customWriteName := "write" customEditName := "edit" customGlobName := "find_files" customGrepName := "search" lsTool, err := newLsTool(backend, customLsName, "") assert.NoError(t, err) info, _ := lsTool.Info(ctx) assert.Equal(t, "list_files", info.Name) readTool, err := newReadFileTool(backend, customReadName, "") assert.NoError(t, err) info, _ = readTool.Info(ctx) assert.Equal(t, "read", info.Name) writeTool, err := newWriteFileTool(backend, customWriteName, "") assert.NoError(t, err) info, _ = writeTool.Info(ctx) assert.Equal(t, "write", info.Name) editTool, err := newEditFileTool(backend, customEditName, "") assert.NoError(t, err) info, _ = editTool.Info(ctx) assert.Equal(t, "edit", info.Name) globTool, err := newGlobTool(backend, customGlobName, "") assert.NoError(t, err) info, _ = globTool.Info(ctx) assert.Equal(t, "find_files", info.Name) grepTool, err := newGrepTool(backend, customGrepName, "") assert.NoError(t, err) info, _ = grepTool.Info(ctx) assert.Equal(t, "search", info.Name) }) t.Run("default tool names when custom names not provided", func(t *testing.T) { lsTool, err := newLsTool(backend, "", "") assert.NoError(t, err) info, _ := lsTool.Info(ctx) assert.Equal(t, ToolNameLs, info.Name) readTool, err := newReadFileTool(backend, "", "") assert.NoError(t, err) info, _ = readTool.Info(ctx) assert.Equal(t, ToolNameReadFile, info.Name) writeTool, err := 
newWriteFileTool(backend, "", "") assert.NoError(t, err) info, _ = writeTool.Info(ctx) assert.Equal(t, ToolNameWriteFile, info.Name) editTool, err := newEditFileTool(backend, "", "") assert.NoError(t, err) info, _ = editTool.Info(ctx) assert.Equal(t, ToolNameEditFile, info.Name) globTool, err := newGlobTool(backend, "", "") assert.NoError(t, err) info, _ = globTool.Info(ctx) assert.Equal(t, ToolNameGlob, info.Name) grepTool, err := newGrepTool(backend, "", "") assert.NoError(t, err) info, _ = grepTool.Info(ctx) assert.Equal(t, ToolNameGrep, info.Name) }) t.Run("custom execute tool name", func(t *testing.T) { customExecuteName := "run_command" shellBackend := &mockShellBackend{ Backend: backend, resp: &filesystem.ExecuteResponse{Output: "ok"}, } executeTool, err := newExecuteTool(shellBackend, customExecuteName, "") assert.NoError(t, err) info, _ := executeTool.Info(ctx) assert.Equal(t, "run_command", info.Name) }) t.Run("custom tool names via ToolConfig in getFilesystemTools", func(t *testing.T) { customLsName := "list_files" customReadName := "read" tools, err := getFilesystemTools(ctx, &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Name: customLsName, }, ReadFileToolConfig: &ToolConfig{ Name: customReadName, }, }) assert.NoError(t, err) toolNames := make(map[string]bool) for _, to := range tools { info, _ := to.Info(ctx) toolNames[info.Name] = true } assert.True(t, toolNames["list_files"]) assert.True(t, toolNames["read"]) }) t.Run("custom tool names via ToolConfig in middleware", func(t *testing.T) { customLsName := "list_files" customReadName := "read" m, err := New(ctx, &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Name: customLsName, }, ReadFileToolConfig: &ToolConfig{ Name: customReadName, }, }) assert.NoError(t, err) fm, ok := m.(*filesystemMiddleware) assert.True(t, ok) toolNames := make(map[string]bool) for _, to := range fm.additionalTools { info, _ := to.Info(ctx) toolNames[info.Name] = true } assert.True(t, 
toolNames["list_files"]) assert.True(t, toolNames["read"]) }) } func TestSelectToolName(t *testing.T) { t.Run("returns custom name when provided", func(t *testing.T) { customName := "custom_tool" result := selectToolName(customName, "default_tool") assert.Equal(t, "custom_tool", result) }) t.Run("returns default name when custom name is nil", func(t *testing.T) { result := selectToolName("", "default_tool") assert.Equal(t, "default_tool", result) }) } func TestGetOrCreateTool(t *testing.T) { backend := setupTestBackend() t.Run("returns custom tool when provided", func(t *testing.T) { customTool, err := newLsTool(backend, "", "") assert.NoError(t, err) result, err := getOrCreateTool(customTool, func() (tool.BaseTool, error) { t.Fatal("createFunc should not be called when custom tool is provided") return nil, nil }) assert.NoError(t, err) assert.Equal(t, customTool, result) }) t.Run("calls createFunc when custom tool is nil", func(t *testing.T) { expectedTool, err := newReadFileTool(backend, "", "") assert.NoError(t, err) createFuncCalled := false result, err := getOrCreateTool(nil, func() (tool.BaseTool, error) { createFuncCalled = true return expectedTool, nil }) assert.NoError(t, err) assert.True(t, createFuncCalled, "createFunc should be called when custom tool is nil") assert.Equal(t, expectedTool, result) }) t.Run("returns nil when custom tool is nil and createFunc returns nil", func(t *testing.T) { result, err := getOrCreateTool(nil, func() (tool.BaseTool, error) { return nil, nil }) assert.NoError(t, err) assert.Nil(t, result) }) t.Run("propagates error from createFunc", func(t *testing.T) { expectedErr := assert.AnError result, err := getOrCreateTool(nil, func() (tool.BaseTool, error) { return nil, expectedErr }) assert.Error(t, err) assert.Equal(t, expectedErr, err) assert.Nil(t, result) }) } func TestCustomTools(t *testing.T) { backend := setupTestBackend() ctx := context.Background() t.Run("custom ls tool is used via ToolConfig", func(t *testing.T) { 
customLsTool, err := newLsTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ LsToolConfig: &ToolConfig{ CustomTool: customLsTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 1) assert.Equal(t, customLsTool, tools[0]) }) t.Run("custom read file tool is used via ToolConfig", func(t *testing.T) { customReadTool, err := newReadFileTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ ReadFileToolConfig: &ToolConfig{ CustomTool: customReadTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 1) assert.Equal(t, customReadTool, tools[0]) }) t.Run("custom write file tool is used via ToolConfig", func(t *testing.T) { customWriteTool, err := newWriteFileTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ WriteFileToolConfig: &ToolConfig{ CustomTool: customWriteTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 1) assert.Equal(t, customWriteTool, tools[0]) }) t.Run("custom edit file tool is used via ToolConfig", func(t *testing.T) { customEditTool, err := newEditFileTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ EditFileToolConfig: &ToolConfig{ CustomTool: customEditTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 1) assert.Equal(t, customEditTool, tools[0]) }) t.Run("custom glob tool is used via ToolConfig", func(t *testing.T) { customGlobTool, err := newGlobTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ GlobToolConfig: &ToolConfig{ CustomTool: customGlobTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 1) assert.Equal(t, customGlobTool, tools[0]) }) t.Run("custom grep tool is used via ToolConfig", func(t *testing.T) { customGrepTool, err := newGrepTool(backend, "", "") assert.NoError(t, err) config := 
&MiddlewareConfig{ GrepToolConfig: &ToolConfig{ CustomTool: customGrepTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 1) assert.Equal(t, customGrepTool, tools[0]) }) t.Run("multiple custom tools can be used together", func(t *testing.T) { customLsTool, err := newLsTool(backend, "", "") assert.NoError(t, err) customReadTool, err := newReadFileTool(backend, "", "") assert.NoError(t, err) customGlobTool, err := newGlobTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ LsToolConfig: &ToolConfig{ CustomTool: customLsTool, }, ReadFileToolConfig: &ToolConfig{ CustomTool: customReadTool, }, GlobToolConfig: &ToolConfig{ CustomTool: customGlobTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 3) }) t.Run("custom tools take precedence over backend", func(t *testing.T) { customLsTool, err := newLsTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ CustomTool: customLsTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) lsToolFound := false for _, t := range tools { if t == customLsTool { lsToolFound = true break } } assert.True(t, lsToolFound, "custom ls tool should be in the tools list") }) t.Run("backend tools are created when custom tools not provided", func(t *testing.T) { config := &MiddlewareConfig{ Backend: backend, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Greater(t, len(tools), 0, "should create backend tools when custom tools not provided") }) } func TestToolConfig(t *testing.T) { backend := setupTestBackend() ctx := context.Background() t.Run("use new ToolConfig", func(t *testing.T) { customName := "my_ls" config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Name: customName, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 6) var lsToolFound bool 
for _, tool := range tools { info, _ := tool.Info(ctx) if info.Name == "my_ls" { lsToolFound = true break } } assert.True(t, lsToolFound) }) t.Run("ToolConfig disabled", func(t *testing.T) { config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Disable: true, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 5) for _, tool := range tools { info, _ := tool.Info(ctx) assert.NotEqual(t, ToolNameLs, info.Name) } }) t.Run("ToolConfig with custom tool", func(t *testing.T) { customLsTool, err := newLsTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ CustomTool: customLsTool, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) var lsToolFound bool for _, tool := range tools { if tool == customLsTool { lsToolFound = true break } } assert.True(t, lsToolFound) }) t.Run("ToolConfig Desc takes precedence over legacy Desc", func(t *testing.T) { customDesc := "new description" legacyDesc := "old description" config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Desc: &customDesc, }, CustomLsToolDesc: &legacyDesc, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) var found bool for _, tool := range tools { info, _ := tool.Info(ctx) if info.Name == ToolNameLs && info.Desc == "new description" { found = true break } } assert.True(t, found) }) t.Run("legacy Desc field still works", func(t *testing.T) { legacyDesc := "legacy description" config := &MiddlewareConfig{ Backend: backend, CustomLsToolDesc: &legacyDesc, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) var found bool for _, tool := range tools { info, _ := tool.Info(ctx) if info.Name == ToolNameLs && info.Desc == "legacy description" { found = true break } } assert.True(t, found) }) t.Run("multiple ToolConfig", func(t *testing.T) { lsName := "my_ls" readName := "my_read" config := &MiddlewareConfig{ Backend: 
backend, LsToolConfig: &ToolConfig{ Name: lsName, }, ReadFileToolConfig: &ToolConfig{ Name: readName, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) toolNames := make(map[string]bool) for _, tool := range tools { info, _ := tool.Info(ctx) toolNames[info.Name] = true } assert.True(t, toolNames["my_ls"]) assert.True(t, toolNames["my_read"]) }) } func TestToolConfigEdgeCases(t *testing.T) { backend := setupTestBackend() ctx := context.Background() t.Run("nil ToolConfig.Desc with nil legacyDesc", func(t *testing.T) { config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Desc: nil, }, CustomLsToolDesc: nil, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) var lsTool tool.BaseTool for _, tool := range tools { info, _ := tool.Info(ctx) if info.Name == ToolNameLs { lsTool = tool break } } assert.NotNil(t, lsTool, "ls tool should be created even with nil Desc") }) t.Run("nil ToolConfig.Desc falls back to legacyDesc", func(t *testing.T) { legacyDesc := "legacy description from pointer" config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Desc: nil, }, CustomLsToolDesc: &legacyDesc, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) var found bool for _, tool := range tools { info, _ := tool.Info(ctx) if info.Name == ToolNameLs && info.Desc == "legacy description from pointer" { found = true break } } assert.True(t, found, "nil ToolConfig.Desc should fall back to legacyDesc") }) t.Run("CustomTool with Disable flag should not create tool", func(t *testing.T) { customLsTool, err := newLsTool(backend, "", "") assert.NoError(t, err) config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ CustomTool: customLsTool, Disable: true, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) for _, tool := range tools { info, _ := tool.Info(ctx) assert.NotEqual(t, ToolNameLs, info.Name, "disabled tool should not be created even if CustomTool is 
set") } }) t.Run("multiple ToolConfig with conflicting settings", func(t *testing.T) { legacyDesc := "legacy ls desc" customDesc := "custom desc" config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Name: "custom_ls", Desc: &customDesc, Disable: false, }, CustomLsToolDesc: &legacyDesc, ReadFileToolConfig: &ToolConfig{ Disable: true, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) hasLsTool := false hasReadTool := false for _, tool := range tools { info, _ := tool.Info(ctx) if info.Name == "custom_ls" { hasLsTool = true assert.Equal(t, "custom desc", info.Desc, "ToolConfig.Desc should take precedence over legacy") } if info.Name == ToolNameReadFile { hasReadTool = true } } assert.True(t, hasLsTool, "ls tool should be created") assert.False(t, hasReadTool, "read_file tool should be disabled") }) t.Run("nil ToolConfig with nil legacyDesc creates default tool", func(t *testing.T) { config := &MiddlewareConfig{ Backend: backend, LsToolConfig: nil, CustomLsToolDesc: nil, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) var lsTool tool.BaseTool for _, tool := range tools { info, _ := tool.Info(ctx) if info.Name == ToolNameLs { lsTool = tool break } } assert.NotNil(t, lsTool, "tool should be created with backend even when config is nil") }) t.Run("empty Name in ToolConfig uses default name", func(t *testing.T) { config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ Name: "", }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) var lsTool tool.BaseTool for _, tool := range tools { info, _ := tool.Info(ctx) if info.Name == ToolNameLs { lsTool = tool break } } assert.NotNil(t, lsTool, "tool should use default name when Name is empty") }) } func TestGetFilesystemTools_DisableAllTools(t *testing.T) { ctx := context.Background() backend := setupTestBackend() config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{Disable: true}, ReadFileToolConfig: 
&ToolConfig{Disable: true}, WriteFileToolConfig: &ToolConfig{Disable: true}, EditFileToolConfig: &ToolConfig{Disable: true}, GlobToolConfig: &ToolConfig{Disable: true}, GrepToolConfig: &ToolConfig{Disable: true}, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 0) } func TestGetFilesystemTools_StreamingShell(t *testing.T) { ctx := context.Background() backend := setupTestBackend() t.Run("returns 7 tools with StreamingShell", func(t *testing.T) { mockSS := &mockStreamingShell{} tools, err := getFilesystemTools(ctx, &MiddlewareConfig{ Backend: backend, StreamingShell: mockSS, }) assert.NoError(t, err) assert.Len(t, tools, 7) toolNames := make([]string, 0, len(tools)) for _, to := range tools { info, _ := to.Info(ctx) toolNames = append(toolNames, info.Name) } assert.Contains(t, toolNames, ToolNameExecute) }) t.Run("StreamingShell takes precedence over Shell", func(t *testing.T) { mockSS := &mockStreamingShell{} shellBackend := &mockShellBackend{ Backend: backend, resp: &filesystem.ExecuteResponse{Output: "ok"}, } // When both are set, Validate should fail config := &MiddlewareConfig{ Backend: backend, Shell: shellBackend, StreamingShell: mockSS, } err := config.Validate() assert.Error(t, err) assert.Contains(t, err.Error(), "shell and streaming shell should not be both set") }) } func TestGetFilesystemTools_NilBackend(t *testing.T) { ctx := context.Background() t.Run("nil backend with shell only returns execute tool", func(t *testing.T) { mockSS := &mockStreamingShell{} config := &MiddlewareConfig{ Backend: nil, StreamingShell: mockSS, } // Validate should fail, but getFilesystemTools itself handles nil backend gracefully tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) // Only execute tool should be returned since backend is nil assert.Len(t, tools, 1) info, _ := tools[0].Info(ctx) assert.Equal(t, ToolNameExecute, info.Name) }) t.Run("nil backend with regular Shell returns execute tool", func(t 
*testing.T) { mockShell := &mockShellBackend{ resp: &filesystem.ExecuteResponse{Output: "ok"}, } config := &MiddlewareConfig{ Backend: nil, Shell: mockShell, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 1) info, _ := tools[0].Info(ctx) assert.Equal(t, ToolNameExecute, info.Name) }) t.Run("nil backend and nil shell returns empty tools", func(t *testing.T) { config := &MiddlewareConfig{ Backend: nil, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 0) }) } func TestGetFilesystemTools_PartialDisable(t *testing.T) { ctx := context.Background() backend := setupTestBackend() config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{Disable: true}, ReadFileToolConfig: &ToolConfig{Disable: true}, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) assert.Len(t, tools, 4) toolNames := make([]string, 0, len(tools)) for _, to := range tools { info, _ := to.Info(ctx) toolNames = append(toolNames, info.Name) } assert.NotContains(t, toolNames, ToolNameLs) assert.NotContains(t, toolNames, ToolNameReadFile) assert.Contains(t, toolNames, ToolNameWriteFile) assert.Contains(t, toolNames, ToolNameEditFile) assert.Contains(t, toolNames, ToolNameGlob) assert.Contains(t, toolNames, ToolNameGrep) } type mockStreamingShell struct{} func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { sr, sw := schema.Pipe[*filesystem.ExecuteResponse](10) go func() { defer sw.Close() sw.Send(&filesystem.ExecuteResponse{ Output: "streaming output", ExitCode: ptrOf(0), }, nil) }() return sr, nil } type mockStreamingShellWithError struct{} func (m *mockStreamingShellWithError) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { return nil, fmt.Errorf("streaming shell error") } type 
mockStreamingShellWithRecvError struct{} func (m *mockStreamingShellWithRecvError) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { sr, sw := schema.Pipe[*filesystem.ExecuteResponse](10) go func() { defer sw.Close() sw.Send(nil, fmt.Errorf("recv error during streaming")) }() return sr, nil } type mockStreamingShellWithExitCode struct { exitCode int } func (m *mockStreamingShellWithExitCode) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { sr, sw := schema.Pipe[*filesystem.ExecuteResponse](10) go func() { defer sw.Close() sw.Send(&filesystem.ExecuteResponse{ Output: "some output", ExitCode: ptrOf(m.exitCode), }, nil) }() return sr, nil } type mockStreamingShellNoOutput struct{} func (m *mockStreamingShellNoOutput) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { sr, sw := schema.Pipe[*filesystem.ExecuteResponse](10) go func() { defer sw.Close() sw.Send(&filesystem.ExecuteResponse{ ExitCode: ptrOf(0), }, nil) }() return sr, nil } type mockStreamingShellTruncated struct{} func (m *mockStreamingShellTruncated) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { sr, sw := schema.Pipe[*filesystem.ExecuteResponse](10) go func() { defer sw.Close() sw.Send(&filesystem.ExecuteResponse{ Output: "partial", Truncated: true, ExitCode: ptrOf(0), }, nil) }() return sr, nil } type mockStreamingShellNilChunk struct{} func (m *mockStreamingShellNilChunk) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { sr, sw := schema.Pipe[*filesystem.ExecuteResponse](10) go func() { defer sw.Close() sw.Send(nil, nil) sw.Send(&filesystem.ExecuteResponse{ Output: "after nil", ExitCode: 
ptrOf(0), }, nil) }() return sr, nil } func TestNewStreamingExecuteTool(t *testing.T) { t.Run("successful streaming execution", func(t *testing.T) { executeTool, err := newStreamingExecuteTool(&mockStreamingShell{}, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) sr, err := st.StreamableRun(context.Background(), `{"command": "echo hello"}`) assert.NoError(t, err) defer sr.Close() var chunks []string for { chunk, recvErr := sr.Recv() if recvErr == io.EOF { break } assert.NoError(t, recvErr) chunks = append(chunks, chunk) } assert.True(t, len(chunks) > 0) result := strings.Join(chunks, "") assert.Contains(t, result, "streaming output") }) t.Run("streaming execution with ExecuteStreaming error", func(t *testing.T) { executeTool, err := newStreamingExecuteTool(&mockStreamingShellWithError{}, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) _, err = st.StreamableRun(context.Background(), `{"command": "fail"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "streaming shell error") }) t.Run("streaming execution with recv error", func(t *testing.T) { executeTool, err := newStreamingExecuteTool(&mockStreamingShellWithRecvError{}, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) sr, err := st.StreamableRun(context.Background(), `{"command": "echo hello"}`) assert.NoError(t, err) defer sr.Close() var gotError bool for { _, recvErr := sr.Recv() if recvErr == io.EOF { break } if recvErr != nil { gotError = true assert.Contains(t, recvErr.Error(), "recv error during streaming") break } } assert.True(t, gotError) }) t.Run("streaming execution with non-zero exit code", func(t *testing.T) { executeTool, err := newStreamingExecuteTool(&mockStreamingShellWithExitCode{exitCode: 1}, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) sr, err := st.StreamableRun(context.Background(), `{"command": "false"}`) assert.NoError(t, err) defer sr.Close() var chunks []string for { chunk, recvErr := 
sr.Recv() if recvErr == io.EOF { break } assert.NoError(t, recvErr) chunks = append(chunks, chunk) } result := strings.Join(chunks, "") assert.Contains(t, result, "[Command failed with exit code 1]") }) t.Run("streaming execution with zero exit code and no output", func(t *testing.T) { executeTool, err := newStreamingExecuteTool(&mockStreamingShellNoOutput{}, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) sr, err := st.StreamableRun(context.Background(), `{"command": "true"}`) assert.NoError(t, err) defer sr.Close() var chunks []string for { chunk, recvErr := sr.Recv() if recvErr == io.EOF { break } assert.NoError(t, recvErr) chunks = append(chunks, chunk) } result := strings.Join(chunks, "") assert.Contains(t, result, "[Command executed successfully with no output]") }) t.Run("streaming execution with truncated output", func(t *testing.T) { executeTool, err := newStreamingExecuteTool(&mockStreamingShellTruncated{}, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) sr, err := st.StreamableRun(context.Background(), `{"command": "cat largefile"}`) assert.NoError(t, err) defer sr.Close() var chunks []string for { chunk, recvErr := sr.Recv() if recvErr == io.EOF { break } assert.NoError(t, recvErr) chunks = append(chunks, chunk) } result := strings.Join(chunks, "") assert.Contains(t, result, "partial") assert.Contains(t, result, "[Output was truncated due to size limits]") }) t.Run("streaming execution with nil chunk skipped", func(t *testing.T) { executeTool, err := newStreamingExecuteTool(&mockStreamingShellNilChunk{}, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) sr, err := st.StreamableRun(context.Background(), `{"command": "echo test"}`) assert.NoError(t, err) defer sr.Close() var chunks []string for { chunk, recvErr := sr.Recv() if recvErr == io.EOF { break } assert.NoError(t, recvErr) chunks = append(chunks, chunk) } result := strings.Join(chunks, "") assert.Contains(t, result, "after nil") }) 
t.Run("streaming execution with custom name and desc", func(t *testing.T) { executeTool, err := newStreamingExecuteTool(&mockStreamingShell{}, "custom_execute", "custom desc") assert.NoError(t, err) info, err := executeTool.Info(context.Background()) assert.NoError(t, err) assert.Equal(t, "custom_execute", info.Name) assert.Equal(t, "custom desc", info.Desc) }) } func TestNew_StreamingShell(t *testing.T) { ctx := context.Background() backend := setupTestBackend() t.Run("StreamingShell adds streaming execute tool", func(t *testing.T) { m, err := New(ctx, &MiddlewareConfig{ Backend: backend, StreamingShell: &mockStreamingShell{}, }) assert.NoError(t, err) fm, ok := m.(*filesystemMiddleware) assert.True(t, ok) assert.Len(t, fm.additionalTools, 7) }) t.Run("both Shell and StreamingShell returns error", func(t *testing.T) { _, err := New(ctx, &MiddlewareConfig{ Backend: backend, Shell: &mockShellBackend{Backend: backend, resp: &filesystem.ExecuteResponse{Output: "ok"}}, StreamingShell: &mockStreamingShell{}, }) assert.Error(t, err) assert.Contains(t, err.Error(), "shell and streaming shell should not be both set") }) } func TestNewMiddleware_Validation(t *testing.T) { ctx := context.Background() t.Run("nil config returns error", func(t *testing.T) { _, err := NewMiddleware(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "config should not be nil") }) t.Run("nil backend returns error", func(t *testing.T) { _, err := NewMiddleware(ctx, &Config{Backend: nil}) assert.Error(t, err) assert.Contains(t, err.Error(), "backend should not be nil") }) t.Run("both Shell and StreamingShell returns error", func(t *testing.T) { backend := setupTestBackend() _, err := NewMiddleware(ctx, &Config{ Backend: backend, Shell: &mockShellBackend{Backend: backend, resp: &filesystem.ExecuteResponse{Output: "ok"}}, StreamingShell: &mockStreamingShell{}, }) assert.Error(t, err) assert.Contains(t, err.Error(), "shell and streaming shell should not be both set") }) } func 
TestMiddlewareConfig_Validate(t *testing.T) { t.Run("nil config returns error", func(t *testing.T) { var c *MiddlewareConfig err := c.Validate() assert.Error(t, err) assert.Contains(t, err.Error(), "config should not be nil") }) t.Run("nil backend returns error", func(t *testing.T) { c := &MiddlewareConfig{} err := c.Validate() assert.Error(t, err) assert.Contains(t, err.Error(), "backend should not be nil") }) t.Run("both shells returns error", func(t *testing.T) { c := &MiddlewareConfig{ Backend: setupTestBackend(), Shell: &mockShellBackend{}, StreamingShell: &mockStreamingShell{}, } err := c.Validate() assert.Error(t, err) assert.Contains(t, err.Error(), "shell and streaming shell should not be both set") }) t.Run("valid config passes", func(t *testing.T) { c := &MiddlewareConfig{ Backend: setupTestBackend(), } err := c.Validate() assert.NoError(t, err) }) } func TestNewStreamingExecuteTool_MultipleChunks(t *testing.T) { mockSS := &mockStreamingShellMultiChunk{} executeTool, err := newStreamingExecuteTool(mockSS, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) sr, err := st.StreamableRun(context.Background(), `{"command": "long-running"}`) assert.NoError(t, err) defer sr.Close() var chunks []string for { chunk, recvErr := sr.Recv() if recvErr == io.EOF { break } assert.NoError(t, recvErr) chunks = append(chunks, chunk) } // Should have received multiple chunks assert.True(t, len(chunks) >= 3) result := strings.Join(chunks, "") assert.Contains(t, result, "chunk1") assert.Contains(t, result, "chunk2") assert.Contains(t, result, "chunk3") } type mockStreamingShellMultiChunk struct{} func (m *mockStreamingShellMultiChunk) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { sr, sw := schema.Pipe[*filesystem.ExecuteResponse](10) go func() { defer sw.Close() sw.Send(&filesystem.ExecuteResponse{Output: "chunk1\n"}, nil) sw.Send(&filesystem.ExecuteResponse{Output: 
"chunk2\n"}, nil) sw.Send(&filesystem.ExecuteResponse{Output: "chunk3\n", ExitCode: ptrOf(0)}, nil) }() return sr, nil } func TestNewStreamingExecuteTool_ExitCodeOnlyInLastChunk(t *testing.T) { mockSS := &mockStreamingShellExitCodeLast{exitCode: 2} executeTool, err := newStreamingExecuteTool(mockSS, "", "") assert.NoError(t, err) st := executeTool.(tool.StreamableTool) sr, err := st.StreamableRun(context.Background(), `{"command": "fail-at-end"}`) assert.NoError(t, err) defer sr.Close() var chunks []string for { chunk, recvErr := sr.Recv() if recvErr == io.EOF { break } assert.NoError(t, recvErr) chunks = append(chunks, chunk) } result := strings.Join(chunks, "") assert.Contains(t, result, "output line") assert.Contains(t, result, "[Command failed with exit code 2]") } type mockStreamingShellExitCodeLast struct { exitCode int } func (m *mockStreamingShellExitCodeLast) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { sr, sw := schema.Pipe[*filesystem.ExecuteResponse](10) go func() { defer sw.Close() sw.Send(&filesystem.ExecuteResponse{Output: "output line"}, nil) sw.Send(&filesystem.ExecuteResponse{ExitCode: ptrOf(m.exitCode)}, nil) }() return sr, nil } func TestConvExecuteResponse_NilResponse(t *testing.T) { result := convExecuteResponse(nil) assert.Equal(t, "", result) } func TestConvExecuteResponse_NilExitCode(t *testing.T) { result := convExecuteResponse(&filesystem.ExecuteResponse{ Output: "some output", }) assert.Equal(t, "some output", result) } func TestConfig_Validate(t *testing.T) { t.Run("nil config returns error", func(t *testing.T) { var c *Config err := c.Validate() assert.Error(t, err) }) t.Run("nil backend returns error", func(t *testing.T) { c := &Config{} err := c.Validate() assert.Error(t, err) assert.Contains(t, err.Error(), "backend should not be nil") }) t.Run("both shells returns error", func(t *testing.T) { c := &Config{ Backend: setupTestBackend(), Shell: 
&mockShellBackend{}, StreamingShell: &mockStreamingShell{}, } err := c.Validate() assert.Error(t, err) }) t.Run("valid config passes", func(t *testing.T) { c := &Config{ Backend: setupTestBackend(), } err := c.Validate() assert.NoError(t, err) }) } func TestGetFilesystemTools_CustomToolWithShell(t *testing.T) { ctx := context.Background() backend := setupTestBackend() t.Run("custom tool replaces default for all disabled except custom", func(t *testing.T) { customLs, err := newLsTool(backend, "my_ls", "my ls desc") assert.NoError(t, err) config := &MiddlewareConfig{ Backend: backend, LsToolConfig: &ToolConfig{ CustomTool: customLs, }, } tools, err := getFilesystemTools(ctx, config) assert.NoError(t, err) var found bool for _, to := range tools { info, _ := to.Info(ctx) if info.Name == "my_ls" { found = true break } } assert.True(t, found) }) } func TestMergeToolConfigWithDesc(t *testing.T) { config := &MiddlewareConfig{Backend: setupTestBackend()} t.Run("both nil returns empty ToolConfig", func(t *testing.T) { result := config.mergeToolConfigWithDesc(nil, nil) assert.NotNil(t, result) assert.Equal(t, "", result.Name) assert.Nil(t, result.Desc) assert.False(t, result.Disable) }) t.Run("nil toolConfig with legacyDesc", func(t *testing.T) { desc := "legacy" result := config.mergeToolConfigWithDesc(nil, &desc) assert.NotNil(t, result) assert.Equal(t, "legacy", *result.Desc) }) t.Run("toolConfig with Desc overrides legacyDesc", func(t *testing.T) { tcDesc := "tc desc" legacyDesc := "legacy" tc := &ToolConfig{Desc: &tcDesc} result := config.mergeToolConfigWithDesc(tc, &legacyDesc) assert.Equal(t, "tc desc", *result.Desc) }) t.Run("toolConfig with nil Desc falls back to legacyDesc", func(t *testing.T) { legacyDesc := "legacy" tc := &ToolConfig{Name: "custom"} result := config.mergeToolConfigWithDesc(tc, &legacyDesc) assert.Equal(t, "legacy", *result.Desc) assert.Equal(t, "custom", result.Name) }) t.Run("toolConfig with nil Desc and nil legacyDesc", func(t *testing.T) { tc 
:= &ToolConfig{Name: "custom"} result := config.mergeToolConfigWithDesc(tc, nil) assert.Nil(t, result.Desc) assert.Equal(t, "custom", result.Name) }) } func TestNewMiddleware_WithShell(t *testing.T) { ctx := context.Background() backend := setupTestBackend() t.Run("Shell backend creates execute tool", func(t *testing.T) { shellBackend := &mockShellBackend{ Backend: backend, resp: &filesystem.ExecuteResponse{Output: "ok"}, } m, err := NewMiddleware(ctx, &Config{ Backend: backend, Shell: shellBackend, }) assert.NoError(t, err) assert.Len(t, m.AdditionalTools, 7) }) t.Run("StreamingShell backend creates streaming execute tool", func(t *testing.T) { m, err := NewMiddleware(ctx, &Config{ Backend: backend, StreamingShell: &mockStreamingShell{}, }) assert.NoError(t, err) assert.Len(t, m.AdditionalTools, 7) }) } func TestNewExecuteTool_ShellError(t *testing.T) { mockShell := &mockShellBackendWithError{} executeTool, err := newExecuteTool(mockShell, "", "") assert.NoError(t, err) result, err := invokeTool(t, executeTool, `{"command": "fail"}`) assert.Error(t, err) assert.Equal(t, "", result) assert.Contains(t, err.Error(), "shell execution error") } type mockShellBackendWithError struct{} func (m *mockShellBackendWithError) Execute(ctx context.Context, req *filesystem.ExecuteRequest) (*filesystem.ExecuteResponse, error) { return nil, errors.New("shell execution error") } ================================================ FILE: adk/middlewares/filesystem/large_tool_result.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package filesystem import ( "bufio" "context" "errors" "fmt" "io" "strings" "unicode/utf8" "github.com/slongfield/pyfmt" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) type toolResultOffloadingConfig struct { Backend filesystem.Backend TokenLimit int PathGenerator func(ctx context.Context, input *compose.ToolInput) (string, error) } func newToolResultOffloading(ctx context.Context, config *toolResultOffloadingConfig) compose.ToolMiddleware { offloading := &toolResultOffloading{ backend: config.Backend, tokenLimit: config.TokenLimit, pathGenerator: config.PathGenerator, } if offloading.tokenLimit == 0 { offloading.tokenLimit = 20000 } if offloading.pathGenerator == nil { offloading.pathGenerator = func(ctx context.Context, input *compose.ToolInput) (string, error) { return fmt.Sprintf("/large_tool_result/%s", input.CallID), nil } } return compose.ToolMiddleware{ Invokable: offloading.invoke, Streamable: offloading.stream, } } type toolResultOffloading struct { backend filesystem.Backend tokenLimit int pathGenerator func(ctx context.Context, input *compose.ToolInput) (string, error) } func (t *toolResultOffloading) invoke(endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { output, err := endpoint(ctx, input) if err != nil { return nil, err } result, err := t.handleResult(ctx, output.Result, input) if err != nil { return nil, err } return 
&compose.ToolOutput{Result: result}, nil } } func (t *toolResultOffloading) stream(endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { output, err := endpoint(ctx, input) if err != nil { return nil, err } result, err := concatString(output.Result) if err != nil { return nil, err } result, err = t.handleResult(ctx, result, input) if err != nil { return nil, err } return &compose.StreamToolOutput{Result: schema.StreamReaderFromArray([]string{result})}, nil } } func (t *toolResultOffloading) handleResult(ctx context.Context, result string, input *compose.ToolInput) (string, error) { if len(result) > t.tokenLimit*4 { path, err := t.pathGenerator(ctx, input) if err != nil { return "", err } nResult := formatToolMessage(result) msgTemplate := internal.SelectPrompt(internal.I18nPrompts{ English: tooLargeToolMessage, Chinese: tooLargeToolMessageChinese, }) nResult, err = pyfmt.Fmt(msgTemplate, map[string]any{ "tool_call_id": input.CallID, "file_path": path, "content_sample": nResult, }) if err != nil { return "", err } err = t.backend.Write(ctx, &WriteRequest{ FilePath: path, Content: result, }) if err != nil { return "", err } return nResult, nil } return result, nil } func concatString(sr *schema.StreamReader[string]) (string, error) { if sr == nil { return "", errors.New("stream is nil") } sb := strings.Builder{} for { str, err := sr.Recv() if errors.Is(err, io.EOF) { return sb.String(), nil } if err != nil { return "", err } sb.WriteString(str) } } func formatToolMessage(s string) string { reader := bufio.NewScanner(strings.NewReader(s)) var b strings.Builder lineNum := 1 for reader.Scan() { if lineNum > 10 { break } line := reader.Text() if utf8.RuneCountInString(line) > 1000 { runes := []rune(line) line = string(runes[:1000]) } b.WriteString(fmt.Sprintf("%d: %s\n", lineNum, line)) lineNum++ } return b.String() } 
================================================ FILE: adk/middlewares/filesystem/large_tool_result_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package filesystem import ( "context" "errors" "fmt" "io" "strings" "testing" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) // mockBackend is a simple in-memory backend for testing type mockBackend struct { files map[string]string } func newMockBackend() *mockBackend { return &mockBackend{ files: make(map[string]string), } } func (m *mockBackend) Write(ctx context.Context, req *WriteRequest) error { m.files[req.FilePath] = req.Content return nil } func (m *mockBackend) Read(ctx context.Context, req *ReadRequest) (*FileContent, error) { content, ok := m.files[req.FilePath] if !ok { return nil, errors.New("file not found") } return &FileContent{Content: content}, nil } func (m *mockBackend) LsInfo(ctx context.Context, _ *LsInfoRequest) ([]FileInfo, error) { return nil, nil } func (m *mockBackend) GrepRaw(ctx context.Context, _ *GrepRequest) ([]GrepMatch, error) { return nil, nil } func (m *mockBackend) GlobInfo(ctx context.Context, _ *GlobInfoRequest) ([]FileInfo, error) { return nil, nil } func (m *mockBackend) Edit(ctx context.Context, _ *EditRequest) error { return nil } func TestToolResultOffloading_SmallResult(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := 
&toolResultOffloadingConfig{ Backend: backend, TokenLimit: 100, // Small limit for testing } middleware := newToolResultOffloading(ctx, config) // Create a mock endpoint that returns a small result smallResult := "This is a small result" mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: smallResult}, nil } // Wrap the endpoint with the middleware wrappedEndpoint := middleware.Invokable(mockEndpoint) // Execute input := &compose.ToolInput{ Name: "test_tool", CallID: "call_123", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Small result should pass through unchanged if output.Result != smallResult { t.Errorf("expected result %q, got %q", smallResult, output.Result) } // No file should be written if len(backend.files) != 0 { t.Errorf("expected no files to be written, got %d files", len(backend.files)) } } func TestToolResultOffloading_LargeResult(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, // Very small limit to trigger offloading } middleware := newToolResultOffloading(ctx, config) // Create a large result (more than 10 * 4 = 40 bytes) largeResult := strings.Repeat("This is a long line of text that will exceed the token limit.\n", 10) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: largeResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_456", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Result should be replaced with a message if !strings.Contains(output.Result, "Tool result too large") { t.Errorf("expected result to contain 'Tool result too large', got %q", output.Result) } if 
!strings.Contains(output.Result, "call_456") { t.Errorf("expected result to contain call ID 'call_456', got %q", output.Result) } if !strings.Contains(output.Result, "/large_tool_result/call_456") { t.Errorf("expected result to contain file path, got %q", output.Result) } // File should be written if len(backend.files) != 1 { t.Fatalf("expected 1 file to be written, got %d files", len(backend.files)) } savedContent, ok := backend.files["/large_tool_result/call_456"] if !ok { t.Fatalf("expected file at /large_tool_result/call_456, got files: %v", backend.files) } if savedContent != largeResult { t.Errorf("saved content doesn't match original result") } } func TestToolResultOffloading_CustomPathGenerator(t *testing.T) { ctx := context.Background() backend := newMockBackend() customPath := "/custom/path/result.txt" config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, PathGenerator: func(ctx context.Context, input *compose.ToolInput) (string, error) { return customPath, nil }, } middleware := newToolResultOffloading(ctx, config) largeResult := strings.Repeat("Large content ", 100) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: largeResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_789", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Check custom path is used if !strings.Contains(output.Result, customPath) { t.Errorf("expected result to contain custom path %q, got %q", customPath, output.Result) } // File should be written to custom path savedContent, ok := backend.files[customPath] if !ok { t.Fatalf("expected file at %q, got files: %v", customPath, backend.files) } if savedContent != largeResult { t.Errorf("saved content doesn't match original result") } } func TestToolResultOffloading_PathGeneratorError(t *testing.T) { ctx := 
context.Background() backend := newMockBackend() expectedErr := errors.New("path generation failed") config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, PathGenerator: func(ctx context.Context, input *compose.ToolInput) (string, error) { return "", expectedErr }, } middleware := newToolResultOffloading(ctx, config) largeResult := strings.Repeat("Large content ", 100) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: largeResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_error", } _, err := wrappedEndpoint(ctx, input) if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, expectedErr) { t.Errorf("expected error %v, got %v", expectedErr, err) } } func TestToolResultOffloading_EndpointError(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 100, } middleware := newToolResultOffloading(ctx, config) expectedErr := errors.New("endpoint execution failed") mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return nil, expectedErr } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_endpoint_error", } _, err := wrappedEndpoint(ctx, input) if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, expectedErr) { t.Errorf("expected error %v, got %v", expectedErr, err) } } func TestToolResultOffloading_DefaultTokenLimit(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 0, // Should default to 20000 } middleware := newToolResultOffloading(ctx, config) // Create a result smaller than 20000 * 4 = 80000 bytes smallResult := strings.Repeat("x", 1000) mockEndpoint := func(ctx 
context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: smallResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_default", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Should pass through unchanged if output.Result != smallResult { t.Errorf("expected result to pass through unchanged") } // No file should be written if len(backend.files) != 0 { t.Errorf("expected no files to be written, got %d files", len(backend.files)) } } func TestToolResultOffloading_Stream(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, } middleware := newToolResultOffloading(ctx, config) // Create a streaming endpoint that returns large content largeResult := strings.Repeat("Large streaming content ", 100) mockStreamEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { // Split the result into chunks chunks := []string{largeResult[:len(largeResult)/2], largeResult[len(largeResult)/2:]} return &compose.StreamToolOutput{ Result: schema.StreamReaderFromArray(chunks), }, nil } wrappedEndpoint := middleware.Streamable(mockStreamEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_stream", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Read the stream var result strings.Builder for { chunk, err := output.Result.Recv() if errors.Is(err, io.EOF) { break } if err != nil { t.Fatalf("error reading stream: %v", err) } result.WriteString(chunk) } resultStr := result.String() // Result should be replaced with a message if !strings.Contains(resultStr, "Tool result too large") { t.Errorf("expected result to contain 'Tool result too large', got %q", resultStr) } if !strings.Contains(resultStr, 
"call_stream") { t.Errorf("expected result to contain call ID 'call_stream', got %q", resultStr) } // File should be written if len(backend.files) != 1 { t.Fatalf("expected 1 file to be written, got %d files", len(backend.files)) } savedContent, ok := backend.files["/large_tool_result/call_stream"] if !ok { t.Fatalf("expected file at /large_tool_result/call_stream, got files: %v", backend.files) } if savedContent != largeResult { t.Errorf("saved content doesn't match original result") } } func TestToolResultOffloading_StreamError(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, } middleware := newToolResultOffloading(ctx, config) expectedErr := errors.New("stream endpoint failed") mockStreamEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { return nil, expectedErr } wrappedEndpoint := middleware.Streamable(mockStreamEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_stream_error", } _, err := wrappedEndpoint(ctx, input) if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, expectedErr) { t.Errorf("expected error %v, got %v", expectedErr, err) } } func TestFormatToolMessage(t *testing.T) { tests := []struct { name string input string expected string }{ { name: "single line", input: "single line", expected: "1: single line\n", }, { name: "multiple lines", input: "line1\nline2\nline3", expected: "1: line1\n2: line2\n3: line3\n", }, { name: "more than 10 lines", input: "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12", expected: "1: 1\n2: 2\n3: 3\n4: 4\n5: 5\n6: 6\n7: 7\n8: 8\n9: 9\n10: 10\n", }, { name: "long line truncation", input: strings.Repeat("a", 1500), expected: fmt.Sprintf("1: %s\n", strings.Repeat("a", 1000)), }, { name: "unicode characters", input: "你好世界\n测试", expected: "1: 你好世界\n2: 测试\n", }, { name: "long unicode line", input: strings.Repeat("你", 1500), expected: fmt.Sprintf("1: 
%s\n", strings.Repeat("你", 1000)), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := formatToolMessage(tt.input) if result != tt.expected { t.Errorf("formatToolMessage() = %q, want %q", result, tt.expected) } }) } } func TestConcatString(t *testing.T) { tests := []struct { name string chunks []string expected string expectError bool }{ { name: "single chunk", chunks: []string{"hello"}, expected: "hello", }, { name: "multiple chunks", chunks: []string{"hello", " ", "world"}, expected: "hello world", }, { name: "empty chunks", chunks: []string{"", "", ""}, expected: "", }, { name: "mixed chunks", chunks: []string{"a", "", "b", "c"}, expected: "abc", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sr := schema.StreamReaderFromArray(tt.chunks) result, err := concatString(sr) if tt.expectError { if err == nil { t.Error("expected error, got nil") } return } if err != nil { t.Fatalf("unexpected error: %v", err) } if result != tt.expected { t.Errorf("concatString() = %q, want %q", result, tt.expected) } }) } // Test nil stream t.Run("nil stream", func(t *testing.T) { _, err := concatString(nil) if err == nil { t.Error("expected error for nil stream, got nil") } if !strings.Contains(err.Error(), "stream is nil") { t.Errorf("expected 'stream is nil' error, got %v", err) } }) } func TestToolResultOffloading_BackendWriteError(t *testing.T) { ctx := context.Background() // Create a backend that fails on write backend := &failingBackend{ writeErr: errors.New("write failed"), } config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, } middleware := newToolResultOffloading(ctx, config) largeResult := strings.Repeat("Large content ", 100) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: largeResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_write_error", 
} _, err := wrappedEndpoint(ctx, input) if err == nil { t.Fatal("expected error, got nil") } if !strings.Contains(err.Error(), "write failed") { t.Errorf("expected 'write failed' error, got %v", err) } } // failingBackend is a mock backend that can be configured to fail type failingBackend struct { writeErr error } func (f *failingBackend) Write(ctx context.Context, req *WriteRequest) error { if f.writeErr != nil { return f.writeErr } return nil } func (f *failingBackend) Read(ctx context.Context, req *ReadRequest) (*FileContent, error) { return &FileContent{}, nil } func (f *failingBackend) LsInfo(ctx context.Context, _ *LsInfoRequest) ([]FileInfo, error) { return nil, nil } func (f *failingBackend) GrepRaw(ctx context.Context, _ *GrepRequest) ([]GrepMatch, error) { return nil, nil } func (f *failingBackend) GlobInfo(ctx context.Context, _ *GlobInfoRequest) ([]FileInfo, error) { return nil, nil } func (f *failingBackend) Edit(ctx context.Context, _ *EditRequest) error { return nil } ================================================ FILE: adk/middlewares/filesystem/prompt.go ================================================ /* * Copyright (c) 2025 Harrison Chase * Copyright (c) 2025 CloudWeGo Authors * SPDX-License-Identifier: MIT * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package filesystem // This file contains prompt templates and tool descriptions adapted from the DeepAgents project. 
// Original source: https://github.com/langchain-ai/deepagents
//
// These prompts are used under the terms of the original project's open source license.
// When using this code in your own open source project, ensure compliance with the original license requirements.

// Prompt texts come in English/Chinese pairs: each constant has a counterpart
// with the "Chinese" suffix carrying the same instructions in Chinese.
const (
	// tooLargeToolMessage and tooLargeToolMessageChinese are pyfmt templates
	// that stand in for an offloaded tool result. Placeholders:
	//   - {tool_call_id}: the ID of the tool call whose result was offloaded
	//   - {file_path}: where the full result was written
	//   - {content_sample}: a numbered preview of the result's first lines
	tooLargeToolMessage = `Tool result too large, the result of this tool call {tool_call_id} was saved in the filesystem at this path: {file_path} You can read the result from the filesystem by using the read_file tool, but make sure to only read part of the result at a time. You can do this by specifying an offset and limit in the read_file tool call. For example, to read the first 100 lines, you can use the read_file tool with offset=0 and limit=100. Here are the first 10 lines of the result: {content_sample}`
	tooLargeToolMessageChinese = `工具结果过大,此工具调用 {tool_call_id} 的结果已保存到文件系统的以下路径:{file_path} 你可以使用 read_file 工具从文件系统读取结果,但请确保每次只读取部分结果。 你可以通过在 read_file 工具调用中指定 offset 和 limit 来实现。 例如,要读取前 100 行,你可以使用 read_file 工具,设置 offset=0 和 limit=100。 以下是结果的前 10 行: {content_sample}`

	// ListFilesToolDesc / ListFilesToolDescChinese describe the directory
	// listing tool (called "ls" in the text).
	ListFilesToolDesc = `Lists all files in the filesystem, filtering by directory. Usage: - The path parameter must be an absolute path, not a relative path - The ls tool will return a list of all files in the specified directory. - This is very useful for exploring the file system and finding the right file to read or edit. - You should almost ALWAYS use this tool before using the read_file or edit_file tools.`
	ListFilesToolDescChinese = `列出文件系统中的所有文件,按目录过滤。 使用方法: - path 参数必须是绝对路径,不能是相对路径 - ls 工具将返回指定目录中所有文件的列表 - 这对于探索文件系统和找到要读取或编辑的正确文件非常有用 - 在使用 read_file 或 edit_file 工具之前,你几乎总是应该先使用此工具`

	// ReadFileToolDesc / ReadFileToolDescChinese describe the read_file tool,
	// including its offset/limit pagination parameters.
	ReadFileToolDesc = `Reads a file from the filesystem. You can access any file directly by using this tool. Assume this tool is able to read all files on the machine. If the User provides a path to a file assume that path is valid. It is okay to read a file that does not exist; an error will be returned.
Usage: - The file_path parameter must be an absolute path, not a relative path - By default, it reads up to 2000 lines starting from the beginning of the file - **IMPORTANT for large files and codebase exploration**: Use pagination with offset and limit parameters to avoid context overflow - First scan: read_file(path, limit=100) to see file structure - Read more sections: read_file(path, offset=100, limit=200) for next 200 lines - Only omit limit (read full file) when necessary for editing - Specify offset and limit: read_file(path, offset=0, limit=100) reads first 100 lines - Results are returned using cat -n format, with line numbers starting at 1 - You have the capability to call multiple tools in a single response. It is always better to speculatively read multiple files as a batch that are potentially useful. - If you read a file that exists but has empty contents you will receive a system reminder warning in place of file contents. - You should ALWAYS make sure a file has been read before editing it.`
	ReadFileToolDescChinese = `从文件系统读取文件。你可以使用此工具直接访问任何文件。 假设此工具能够读取机器上的所有文件。如果用户提供了文件路径,假设该路径是有效的。读取不存在的文件是可以的;将返回错误。 使用方法: - file_path 参数必须是绝对路径,不能是相对路径 - 默认情况下,从文件开头读取最多 2000 行 - **大文件和代码库探索的重要提示**:使用 offset 和 limit 参数进行分页,以避免上下文溢出 - 首次扫描:read_file(path, limit=100) 查看文件结构 - 读取更多部分:read_file(path, offset=100, limit=200) 读取接下来的 200 行 - 仅在编辑必要时才省略 limit(读取完整文件) - 指定 offset 和 limit:read_file(path, offset=0, limit=100) 读取前 100 行 - 结果以 cat -n 格式返回,行号从 1 开始 - 你可以在单个响应中调用多个工具。最好同时推测性地批量读取多个可能有用的文件 - 如果你读取的文件存在但内容为空,你将收到系统提醒警告而不是文件内容 - 在编辑文件之前,你应该始终确保已读取该文件`

	// EditFileToolDesc / EditFileToolDescChinese describe the exact-string
	// replacement tool (old_string / new_string / replace_all).
	// NOTE(review): the text refers to the "Read tool" and a "spaces + line
	// number + tab" prefix; confirm this matches the read_file tool's actual
	// output format.
	EditFileToolDesc = `Performs exact string replacements in files. Usage: - You must use your 'read_file' tool at least once in the conversation before editing. This tool will error if you attempt an edit without reading the file. - When editing text from Read tool output, ensure you preserve the exact indentation (tabs/spaces) as it appears AFTER the line number prefix.
The line number prefix format is: spaces + line number + tab. Everything after that tab is the actual file content to match. Never include any part of the line number prefix in the old_string or new_string. - ALWAYS prefer editing existing files. NEVER write new files unless explicitly required. - Only use emojis if the user explicitly requests it. Avoid adding emojis to files unless asked. - The edit will FAIL if 'old_string' is not unique in the file. Either provide a larger string with more surrounding context to make it unique or use 'replace_all' to change every instance of 'old_string'. - Use 'replace_all' for replacing and renaming strings across the file. This parameter is useful if you want to rename a variable for instance.`
	EditFileToolDescChinese = `在文件中执行精确的字符串替换。 使用方法: - 在编辑之前,你必须在对话中至少使用一次 'read_file' 工具。如果你在未读取文件的情况下尝试编辑,此工具将报错 - 当从 Read 工具输出编辑文本时,请确保保留行号前缀之后的确切缩进(制表符/空格)。行号前缀格式为:空格 + 行号 + 制表符。制表符之后的所有内容都是要匹配的实际文件内容。永远不要在 old_string 或 new_string 中包含行号前缀的任何部分 - 始终优先编辑现有文件。除非明确要求,否则不要创建新文件 - 仅在用户明确要求时使用表情符号。除非被要求,否则避免在文件中添加表情符号 - 如果 'old_string' 在文件中不唯一,编辑将失败。要么提供包含更多上下文的更长字符串使其唯一,要么使用 'replace_all' 更改 'old_string' 的每个实例 - 使用 'replace_all' 在整个文件中替换和重命名字符串。例如,如果你想重命名变量,此参数很有用`

	// WriteFileToolDesc / WriteFileToolDescChinese describe the whole-file
	// write tool, which overwrites any existing file at the given path.
	WriteFileToolDesc = `Writes a file to the local filesystem. Usage: - This tool will overwrite the existing file if there is one at the provided path. - If this is an existing file, you MUST use the Read tool first to read the file's contents. This tool will fail if you did not read the file first. - ALWAYS prefer editing existing files in the codebase. NEVER write new files unless explicitly required. - NEVER proactively create documentation files (*.md) or README files. Only create documentation files if explicitly requested by the User. - Only use emojis if the user explicitly requests it.
Avoid writing emojis to files unless asked.`
	WriteFileToolDescChinese = `将文件写入本地文件系统。 使用方法: - 如果提供的路径已存在文件,此工具将覆盖现有文件 - 如果这是一个现有文件,你必须先使用 Read 工具读取文件内容。如果你没有先读取文件,此工具将失败 - 始终优先编辑代码库中的现有文件。除非明确要求,否则不要创建新文件 - 不要主动创建文档文件(*.md)或 README 文件。仅在用户明确要求时才创建文档文件 - 仅在用户明确要求时使用表情符号。除非被要求,否则避免在文件中写入表情符号`

	// GlobToolDesc / GlobToolDescChinese describe the glob file-pattern
	// matching tool.
	GlobToolDesc = `Fast file pattern matching tool that works with any codebase size - Supports glob patterns like "**/*.js" or "src/**/*.ts" - Returns matching file paths sorted by modification time - Use this tool when you need to find files by name patterns - You can call multiple tools in a single response. It is always better to speculatively perform multiple searches in parallel if they are potentially useful. Examples: - '**/*.py' - Find all Python files - '*.txt' - Find all text files in root - '/subdir/**/*.md' - Find all markdown files under /subdir`
	GlobToolDescChinese = `适用于任何代码库大小的快速文件模式匹配工具 - 支持 glob 模式,如 "**/*.js" 或 "src/**/*.ts" - 返回按修改时间排序的匹配文件路径 - 当你需要按名称模式查找文件时使用此工具 - 你可以在单个响应中调用多个工具。最好同时并行执行多个可能有用的搜索 示例: - '**/*.py' - 查找所有 Python 文件 - '*.txt' - 查找根目录中的所有文本文件 - '/subdir/**/*.md' - 查找 /subdir 下的所有 markdown 文件`

	// GrepToolDesc / GrepToolDescChinese describe the content search tool.
	// NOTE(review): the text mentions a "Task tool" and "Bash" commands;
	// confirm those tools are actually available to agents using this
	// middleware.
	GrepToolDesc = ` A powerful search tool built on ripgrep Usage: - ALWAYS use Grep for search tasks. NEVER invoke 'grep' or 'rg' as a Bash command. The Grep tool has been optimized for correct permissions and access. - Supports full regex syntax (e.g., "log.*Error", "function\s+\w+") - Filter files with glob parameter (e.g., "*.js", "**/*.tsx") or type parameter (e.g., "js", "py", "rust") - Output modes: "content" shows matching lines, "files_with_matches" shows only file paths (default), "count" shows match counts - Use Task tool for open-ended searches requiring multiple rounds - Pattern syntax: Uses ripgrep (not grep) - literal braces need escaping (use 'interface\{\}' to find 'interface{}' in Go code) - Multiline matching: By default patterns match within single lines only.
For cross-line patterns like 'struct \{[\s\S]*?field', use 'multiline: true'`
	GrepToolDescChinese = ` 基于 ripgrep 的强大搜索工具 使用方法: - 始终使用 Grep 进行搜索任务。不要将 'grep' 或 'rg' 作为 Bash 命令调用。Grep 工具已针对正确的权限和访问进行了优化 - 支持完整的正则表达式语法(例如,"log.*Error","function\s+\w+") - 使用 glob 参数(例如,"*.js","**/*.tsx")或 type 参数(例如,"js","py","rust")过滤文件 - 输出模式:"content" 显示匹配行,"files_with_matches" 仅显示文件路径(默认),"count" 显示匹配计数 - 对于需要多轮的开放式搜索,使用 Task 工具 - 模式语法:使用 ripgrep(不是 grep)- 字面大括号需要转义(使用 'interface\{\}' 在 Go 代码中查找 'interface{}') - 多行匹配:默认情况下,模式仅在单行内匹配。对于跨行模式如 'struct \{[\s\S]*?field',使用 'multiline: true'`

	// ExecuteToolDesc / ExecuteToolDescChinese describe the "execute" tool,
	// which runs shell commands in a sandbox environment.
	ExecuteToolDesc = ` Executes a given command in the sandbox environment with proper handling and security measures. Before executing the command, please follow these steps: 1. Directory Verification: - If the command will create new directories or files, first use the ls tool to verify the parent directory exists and is the correct location - For example, before running "mkdir foo/bar", first use ls to check that "foo" exists and is the intended parent directory 2. Command Execution: - Always quote file paths that contain spaces with double quotes (e.g., cd "path with spaces/file.txt") - Examples of proper quoting: - cd "/Users/name/My Documents" (correct) - cd /Users/name/My Documents (incorrect - will fail) - python "/path/with spaces/script.py" (correct) - python /path/with spaces/script.py (incorrect - will fail) - After ensuring proper quoting, execute the command - Capture the output of the command Usage notes: - The command parameter is required - Commands run in an isolated sandbox environment - Returns combined stdout/stderr output with exit code - If the output is very large, it may be truncated - VERY IMPORTANT: You MUST avoid using search commands like find and grep. Instead use the grep, glob tools to search. You MUST avoid read tools like cat, head, tail, and use read_file to read files. - When issuing multiple commands, use the ';' or '&&' operator to separate them.
DO NOT use newlines (newlines are ok in quoted strings) - Use '&&' when commands depend on each other (e.g., "mkdir dir && cd dir") - Use ';' only when you need to run commands sequentially but don't care if earlier commands fail - Try to maintain your current working directory throughout the session by using absolute paths and avoiding usage of cd Examples: Good examples: - execute(command="pytest /foo/bar/tests") - execute(command="python /path/to/script.py") - execute(command="npm install && npm test") Bad examples (avoid these): - execute(command="cd /foo/bar && pytest tests") # Use absolute path instead - execute(command="cat file.txt") # Use read_file tool instead - execute(command="find . -name '*.py'") # Use glob tool instead - execute(command="grep -r 'pattern' .") # Use grep tool instead `
	ExecuteToolDescChinese = ` 在沙箱环境中执行给定命令,具有适当的处理和安全措施。 执行命令前,请按照以下步骤操作: 1. 目录验证: - 如果命令将创建新目录或文件,首先使用 ls 工具验证父目录是否存在且是正确的位置 - 例如,在运行 "mkdir foo/bar" 之前,首先使用 ls 检查 "foo" 是否存在且是预期的父目录 2. 命令执行: - 始终用双引号引用包含空格的文件路径(例如,cd "path with spaces/file.txt") - 正确引用的示例: - cd "/Users/name/My Documents"(正确) - cd /Users/name/My Documents(错误 - 将失败) - python "/path/with spaces/script.py"(正确) - python /path/with spaces/script.py(错误 - 将失败) - 确保正确引用后,执行命令 - 捕获命令的输出 使用说明: - command 参数是必需的 - 命令在隔离的沙箱环境中运行 - 返回合并的 stdout/stderr 输出和退出代码 - 如果输出非常大,可能会被截断 - 非常重要:你必须避免使用 find 和 grep 等搜索命令。请改用 grep、glob 工具进行搜索。你必须避免使用 cat、head、tail 等读取工具,请使用 read_file 读取文件 - 发出多个命令时,使用 ';' 或 '&&' 运算符分隔它们。不要使用换行符(引号字符串中的换行符是可以的) - 当命令相互依赖时使用 '&&'(例如,"mkdir dir && cd dir") - 仅当你需要按顺序运行命令但不关心早期命令是否失败时使用 ';' - 尝试通过使用绝对路径并避免使用 cd 来在整个会话中保持当前工作目录 示例: 好的示例: - execute(command="pytest /foo/bar/tests") - execute(command="python /path/to/script.py") - execute(command="npm install && npm test") 不好的示例(避免这些): - execute(command="cd /foo/bar && pytest tests") # 改用绝对路径 - execute(command="cat file.txt") # 改用 read_file 工具 - execute(command="find .
-name '*.py'") # 改用 glob 工具 - execute(command="grep -r 'pattern' .") # 改用 grep 工具 `
)

================================================
FILE: adk/middlewares/patchtoolcalls/patchtoolcalls.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Package patchtoolcalls provides a middleware that patches dangling tool calls in the message history.
package patchtoolcalls

import (
	"context"
	"fmt"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/adk/internal"
	"github.com/cloudwego/eino/schema"
)

// Config defines the configuration options for the patch tool calls middleware.
// A nil *Config may be passed to New and is treated like an empty Config.
type Config struct {
	// PatchedContentGenerator is an optional custom function to generate the content
	// of patched tool messages. If not provided, a default message will be used.
	//
	// Parameters:
	// - ctx: the context for the operation
	// - toolName: the name of the tool that was called
	// - toolCallID: the id of the tool call
	//
	// Returns:
	// - string: the content to use for the patched tool message
	// - error: any error that occurred during generation
	PatchedContentGenerator func(ctx context.Context, toolName, toolCallID string) (string, error)
}

// New creates a new patch tool calls middleware with the given configuration.
//
// The middleware scans the message history before each model invocation and inserts
// placeholder tool messages for any tool calls that don't have corresponding responses.
func New(ctx context.Context, cfg *Config) (adk.ChatModelAgentMiddleware, error) { if cfg == nil { cfg = &Config{} } return &middleware{ BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, gen: cfg.PatchedContentGenerator, }, nil } type middleware struct { *adk.BaseChatModelAgentMiddleware gen func(ctx context.Context, toolName, toolCallID string) (string, error) } func (m *middleware) BeforeModelRewriteState(ctx context.Context, state *adk.ChatModelAgentState, mc *adk.ModelContext) (context.Context, *adk.ChatModelAgentState, error) { if len(state.Messages) == 0 { return ctx, state, nil } patched := make([]adk.Message, 0, len(state.Messages)) for i, msg := range state.Messages { patched = append(patched, msg) if msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 { continue } for _, tc := range msg.ToolCalls { if hasCorrespondingToolMessage(state.Messages[i+1:], tc.ID) { continue } toolMsg, err := m.createPatchedToolMessage(ctx, tc) if err != nil { return ctx, nil, err } patched = append(patched, toolMsg) } } nState := *state nState.Messages = patched return ctx, &nState, nil } func hasCorrespondingToolMessage(messages []adk.Message, toolCallID string) bool { for _, msg := range messages { if msg.Role == schema.Tool && msg.ToolCallID == toolCallID { return true } } return false } func (m *middleware) createPatchedToolMessage(ctx context.Context, tc schema.ToolCall) (adk.Message, error) { if m.gen != nil { content, err := m.gen(ctx, tc.Function.Name, tc.ID) if err != nil { return nil, err } return schema.ToolMessage(content, tc.ID, schema.WithToolName(tc.Function.Name)), nil } tpl := internal.SelectPrompt(internal.I18nPrompts{ English: defaultPatchedToolMessageTemplate, Chinese: defaultPatchedToolMessageTemplateChinese, }) return schema.ToolMessage(fmt.Sprintf(tpl, tc.Function.Name, tc.ID), tc.ID, schema.WithToolName(tc.Function.Name)), nil } const ( defaultPatchedToolMessageTemplate = "Tool call %s with id %s was cancelled - another message came 
in before it could be completed." defaultPatchedToolMessageTemplateChinese = "工具调用 %s(ID 为 %s)已被取消——在其完成之前收到了另一条消息。" ) ================================================ FILE: adk/middlewares/patchtoolcalls/patchtoolcalls_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package patchtoolcalls import ( "context" "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/schema" ) func TestPatchToolCalls(t *testing.T) { ctx := context.Background() m, err := New(ctx, nil) assert.NoError(t, err) // empty messages state := &adk.ChatModelAgentState{ Messages: nil, } _, newState, err := m.BeforeModelRewriteState(ctx, state, nil) assert.NoError(t, err) assert.Len(t, newState.Messages, 0) state = &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("hello"), schema.AssistantMessage("hi there", nil), }, } _, newState, err = m.BeforeModelRewriteState(ctx, state, nil) assert.NoError(t, err) assert.Len(t, newState.Messages, 2) state = &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("hello"), schema.AssistantMessage("", []schema.ToolCall{ {ID: "call_1", Function: schema.FunctionCall{Name: "tool_a"}}, {ID: "call_2", Function: schema.FunctionCall{Name: "tool_b"}}, }), schema.ToolMessage("result_a", "call_1", schema.WithToolName("tool_a")), }, } _, newState, err = m.BeforeModelRewriteState(ctx, state, nil) 
assert.NoError(t, err) patchedMsg := newState.Messages[2] assert.Equal(t, schema.Tool, patchedMsg.Role) assert.Equal(t, "call_2", patchedMsg.ToolCallID) assert.Equal(t, "tool_b", patchedMsg.ToolName) assert.Equal(t, fmt.Sprintf(defaultPatchedToolMessageTemplate, "tool_b", "call_2"), patchedMsg.Content) m, err = New(ctx, &Config{ PatchedContentGenerator: func(ctx context.Context, toolName, toolCallID string) (string, error) { return fmt.Sprintf("123 %s %s", toolName, toolCallID), nil }, }) assert.NoError(t, err) state = &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("hello"), schema.AssistantMessage("", []schema.ToolCall{ {ID: "call_1", Function: schema.FunctionCall{Name: "tool_a"}}, {ID: "call_2", Function: schema.FunctionCall{Name: "tool_b"}}, }), schema.ToolMessage("result_a", "call_1", schema.WithToolName("tool_a")), }, } _, newState, err = m.BeforeModelRewriteState(ctx, state, nil) assert.NoError(t, err) patchedMsg = newState.Messages[2] assert.Equal(t, schema.Tool, patchedMsg.Role) assert.Equal(t, "call_2", patchedMsg.ToolCallID) assert.Equal(t, "tool_b", patchedMsg.ToolName) assert.Equal(t, "123 tool_b call_2", patchedMsg.Content) } ================================================ FILE: adk/middlewares/plantask/backend_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package plantask import ( "context" "errors" "path/filepath" "strings" "sync" fspkg "github.com/cloudwego/eino/adk/filesystem" ) type inMemoryBackend struct { files map[string]string mu sync.RWMutex } func newInMemoryBackend() *inMemoryBackend { return &inMemoryBackend{ files: make(map[string]string), } } func (b *inMemoryBackend) LsInfo(ctx context.Context, req *LsInfoRequest) ([]FileInfo, error) { b.mu.RLock() defer b.mu.RUnlock() reqPath := strings.TrimSuffix(req.Path, "/") var result []FileInfo for path := range b.files { dir := filepath.Dir(path) if dir == reqPath { result = append(result, FileInfo{Path: path}) } } return result, nil } func (b *inMemoryBackend) Read(ctx context.Context, req *ReadRequest) (*fspkg.FileContent, error) { b.mu.RLock() defer b.mu.RUnlock() content, ok := b.files[req.FilePath] if !ok { return nil, errors.New("file not found") } return &fspkg.FileContent{Content: content}, nil } func (b *inMemoryBackend) Write(ctx context.Context, req *WriteRequest) error { b.mu.Lock() defer b.mu.Unlock() b.files[req.FilePath] = req.Content return nil } func (b *inMemoryBackend) Delete(ctx context.Context, req *DeleteRequest) error { b.mu.Lock() defer b.mu.Unlock() delete(b.files, req.FilePath) return nil } ================================================ FILE: adk/middlewares/plantask/plantask.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package plantask import ( "context" "fmt" "sync" "github.com/cloudwego/eino/adk" ) // Config is the configuration for the tool search middleware. type Config struct { Backend Backend BaseDir string } // New creates a new plantask middleware that provides task management tools for agents. // It adds TaskCreate, TaskGet, TaskUpdate, and TaskList tools to the agent's tool set, // allowing agents to create and manage structured task lists during coding sessions. func New(ctx context.Context, config *Config) (adk.ChatModelAgentMiddleware, error) { if config == nil { return nil, fmt.Errorf("config is required") } if config.Backend == nil { return nil, fmt.Errorf("backend is required") } if config.BaseDir == "" { return nil, fmt.Errorf("baseDir is required") } return &middleware{backend: config.Backend, baseDir: config.BaseDir}, nil } type middleware struct { adk.BaseChatModelAgentMiddleware backend Backend baseDir string } func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { if runCtx == nil { return ctx, runCtx, nil } nRunCtx := *runCtx lock := sync.Mutex{} nRunCtx.Tools = append(nRunCtx.Tools, newTaskCreateTool(m.backend, m.baseDir, &lock), newTaskGetTool(m.backend, m.baseDir, &lock), newTaskUpdateTool(m.backend, m.baseDir, &lock), newTaskListTool(m.backend, m.baseDir, &lock), ) return ctx, &nRunCtx, nil } ================================================ FILE: adk/middlewares/plantask/plantask_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package plantask import ( "context" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/tool" ) func TestNew(t *testing.T) { ctx := context.Background() _, err := New(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "config is required") _, err = New(ctx, &Config{}) assert.Error(t, err) assert.Contains(t, err.Error(), "backend is required") _, err = New(ctx, &Config{Backend: newInMemoryBackend()}) assert.Error(t, err) assert.Contains(t, err.Error(), "baseDir is required") m, err := New(ctx, &Config{Backend: newInMemoryBackend(), BaseDir: "/tmp/tasks"}) assert.NoError(t, err) assert.NotNil(t, m) } func TestMiddlewareBeforeAgent(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" m, err := New(ctx, &Config{Backend: backend, BaseDir: baseDir}) assert.NoError(t, err) mw := m.(*middleware) ctx, runCtx, err := mw.BeforeAgent(ctx, nil) assert.NoError(t, err) assert.Nil(t, runCtx) runCtx = &adk.ChatModelAgentContext{ Tools: []tool.BaseTool{}, } ctx, newRunCtx, err := mw.BeforeAgent(ctx, runCtx) assert.NoError(t, err) assert.NotNil(t, newRunCtx) assert.Len(t, newRunCtx.Tools, 4) toolNames := make([]string, 0, 4) for _, t := range newRunCtx.Tools { info, _ := t.Info(ctx) toolNames = append(toolNames, info.Name) } assert.Contains(t, toolNames, "TaskCreate") assert.Contains(t, toolNames, "TaskGet") assert.Contains(t, toolNames, "TaskUpdate") assert.Contains(t, toolNames, "TaskList") } func TestIntegration(t *testing.T) { ctx := 
context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} createTool := newTaskCreateTool(backend, baseDir, lock) getTool := newTaskGetTool(backend, baseDir, lock) updateTool := newTaskUpdateTool(backend, baseDir, lock) listTool := newTaskListTool(backend, baseDir, lock) result, err := createTool.InvokableRun(ctx, `{"subject": "Task 1", "description": "First task"}`) assert.NoError(t, err) assert.Contains(t, result, "Task #1") result, err = createTool.InvokableRun(ctx, `{"subject": "Task 2", "description": "Second task"}`) assert.NoError(t, err) assert.Contains(t, result, "Task #2") _, err = updateTool.InvokableRun(ctx, `{"taskId": "2", "addBlockedBy": ["1"]}`) assert.NoError(t, err) result, err = listTool.InvokableRun(ctx, `{}`) assert.NoError(t, err) assert.Contains(t, result, "#1 [pending] Task 1") assert.Contains(t, result, "#2 [pending] Task 2") assert.Contains(t, result, "[blocked by #1]") _, err = updateTool.InvokableRun(ctx, `{"taskId": "1", "status": "in_progress"}`) assert.NoError(t, err) result, err = getTool.InvokableRun(ctx, `{"taskId": "1"}`) assert.NoError(t, err) assert.Contains(t, result, "Status: in_progress") _, err = updateTool.InvokableRun(ctx, `{"taskId": "1", "status": "completed"}`) assert.NoError(t, err) result, err = listTool.InvokableRun(ctx, `{}`) assert.NoError(t, err) assert.Contains(t, result, "#1 [completed] Task 1") } ================================================ FILE: adk/middlewares/plantask/task.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package plantask import ( "context" "regexp" "github.com/cloudwego/eino/adk/middlewares/filesystem" ) var validTaskIDRegex = regexp.MustCompile(`^\d+$`) const highWatermarkFileName = ".highwatermark" type task struct { ID string `json:"id"` Subject string `json:"subject"` Description string `json:"description"` Status string `json:"status"` Blocks []string `json:"blocks"` BlockedBy []string `json:"blockedBy"` ActiveForm string `json:"activeForm,omitempty"` Owner string `json:"owner,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } type taskOut struct { Result string `json:"result"` } const ( taskStatusPending = "pending" taskStatusInProgress = "in_progress" taskStatusCompleted = "completed" taskStatusDeleted = "deleted" ) type FileInfo = filesystem.FileInfo type LsInfoRequest = filesystem.LsInfoRequest type ReadRequest = filesystem.ReadRequest type WriteRequest = filesystem.WriteRequest type DeleteRequest struct { FilePath string } // Backend defines the storage interface for task persistence. // Implementations can use local filesystem, remote storage, or any other storage backend. type Backend interface { // LsInfo lists file information in the specified directory. LsInfo(ctx context.Context, req *LsInfoRequest) ([]FileInfo, error) // Read reads the content of a file. Read(ctx context.Context, req *ReadRequest) (*filesystem.FileContent, error) // Write writes content to a file, creating it if it doesn't exist. Write(ctx context.Context, req *WriteRequest) error // Delete removes a file from storage. 
Delete(ctx context.Context, req *DeleteRequest) error } func isValidTaskID(taskID string) bool { return validTaskIDRegex.MatchString(taskID) } func appendUnique(slice []string, items ...string) []string { seen := make(map[string]struct{}, len(slice)) for _, s := range slice { seen[s] = struct{}{} } for _, item := range items { if _, exists := seen[item]; !exists { slice = append(slice, item) seen[item] = struct{}{} } } return slice } func hasCyclicDependency(taskMap map[string]*task, blockerID, blockedID string) bool { if blockerID == blockedID { return true } visited := make(map[string]bool) return canReach(taskMap, blockedID, blockerID, visited) } func canReach(taskMap map[string]*task, fromID, toID string, visited map[string]bool) bool { if fromID == toID { return true } if visited[fromID] { return false } visited[fromID] = true fromTask, exists := taskMap[fromID] if !exists { return false } for _, blockedID := range fromTask.Blocks { if canReach(taskMap, blockedID, toID, visited) { return true } } return false } ================================================ FILE: adk/middlewares/plantask/task_create.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package plantask import ( "context" "fmt" "path/filepath" "sync" "github.com/bytedance/sonic" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) func newTaskCreateTool(backend Backend, baseDir string, lock *sync.Mutex) *taskCreateTool { return &taskCreateTool{Backend: backend, BaseDir: baseDir, lock: lock} } type taskCreateTool struct { Backend Backend BaseDir string lock *sync.Mutex } type taskCreateArgs struct { Subject string `json:"subject"` Description string `json:"description"` ActiveForm string `json:"activeForm,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } func (t *taskCreateTool) Info(ctx context.Context) (*schema.ToolInfo, error) { desc := internal.SelectPrompt(internal.I18nPrompts{ English: taskCreateToolDesc, Chinese: taskCreateToolDescChinese, }) return &schema.ToolInfo{ Name: TaskCreateToolName, Desc: desc, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "subject": { Type: schema.String, Desc: "A brief title for the task", Required: true, }, "description": { Type: schema.String, Desc: "A detailed description of what needs to be done", Required: true, }, "activeForm": { Type: schema.String, Desc: "Present continuous form shown in spinner when in_progress (e.g., \"Running tests\")", Required: false, }, "metadata": { Type: schema.Object, Desc: "Arbitrary metadata to attach to the task", SubParams: map[string]*schema.ParameterInfo{ "propertyNames": { Type: schema.String, }, }, Required: false, }, }), }, nil } func (t *taskCreateTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { t.lock.Lock() defer t.lock.Unlock() params := &taskCreateArgs{} err := sonic.UnmarshalString(argumentsInJSON, params) if err != nil { return "", err } files, err := t.Backend.LsInfo(ctx, &LsInfoRequest{ Path: t.BaseDir, }) if err != nil { return "", fmt.Errorf("%s list files in %s failed, err: %w", 
TaskCreateToolName, t.BaseDir, err) } highwatermark := int64(0) for _, file := range files { fileName := filepath.Base(file.Path) if fileName == highWatermarkFileName { content, readErr := t.Backend.Read(ctx, &ReadRequest{ FilePath: file.Path, }) if readErr != nil { return "", fmt.Errorf("%s read highwatermark file %s failed, err: %w", TaskCreateToolName, file.Path, readErr) } if content.Content != "" { var val int64 if _, scanErr := fmt.Sscanf(content.Content, "%d", &val); scanErr == nil { highwatermark = val } } break } } taskID := highwatermark + 1 taskFileName := fmt.Sprintf("%d.json", taskID) for _, file := range files { fileName := filepath.Base(file.Path) if fileName == taskFileName { return "", fmt.Errorf("Task #%d already exists", taskID) } } newTask := &task{ ID: fmt.Sprintf("%d", taskID), Subject: params.Subject, Description: params.Description, Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, ActiveForm: params.ActiveForm, Metadata: params.Metadata, } taskData, err := sonic.MarshalString(newTask) if err != nil { return "", fmt.Errorf("%s marshal task #%d failed, err: %w", TaskCreateToolName, taskID, err) } // Write highwatermark file first highwatermarkPath := filepath.Join(t.BaseDir, highWatermarkFileName) err = t.Backend.Write(ctx, &WriteRequest{ FilePath: highwatermarkPath, Content: fmt.Sprintf("%d", taskID), }) if err != nil { return "", fmt.Errorf("%s update highwatermark file %s failed, err: %w", TaskCreateToolName, highwatermarkPath, err) } taskFilePath := filepath.Join(t.BaseDir, taskFileName) err = t.Backend.Write(ctx, &WriteRequest{ FilePath: taskFilePath, Content: taskData, }) if err != nil { return "", fmt.Errorf("%s create Task #%d failed, err: %w", TaskCreateToolName, taskID, err) } resp := &taskOut{ Result: fmt.Sprintf("Task #%d created successfully: %s", taskID, params.Subject), } jsonResp, err := sonic.MarshalString(resp) if err != nil { return "", fmt.Errorf("%s marshal taskOut failed, err: %w", TaskCreateToolName, 
err) } return jsonResp, nil } const TaskCreateToolName = "TaskCreate" const taskCreateToolDesc = `Use this tool to create a structured task list for your current coding session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user. It also helps the user understand the progress of the task and overall progress of their requests. ## When to Use This Tool Use this tool proactively in these scenarios: - Complex multi-step tasks - When a task requires 3 or more distinct steps or actions - Non-trivial and complex tasks - Tasks that require careful planning or multiple operations - Plan mode - When using plan mode, create a task list to track the work - User explicitly requests todo list - When the user directly asks you to use the todo list - User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated) - After receiving new instructions - Immediately capture user requirements as tasks - When you start working on a task - Mark it as in_progress BEFORE beginning work - After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation ## When NOT to Use This Tool Skip using this tool when: - There is only a single, straightforward task - The task is trivial and tracking it provides no organizational benefit - The task can be completed in less than 3 trivial steps - The task is purely conversational or informational NOTE that you should not use this tool if there is only one trivial task to do. In this case you are better off just doing the task directly. ## Task Fields - **subject**: A brief, actionable title in imperative form (e.g., "Fix authentication bug in login flow") - **description**: Detailed description of what needs to be done, including context and acceptance criteria - **activeForm**: Present continuous form shown in spinner when task is in_progress (e.g., "Fixing authentication bug"). 
This is displayed to the user while you work on the task. **IMPORTANT**: Always provide activeForm when creating tasks. The subject should be imperative ("Run tests") while activeForm should be present continuous ("Running tests"). All tasks are created with status "pending". ## Tips - Create tasks with clear, specific subjects that describe the outcome - Include enough detail in the description for another agent to understand and complete the task - After creating tasks, use TaskUpdate to set up dependencies (blocks/blockedBy) if needed - Check TaskList first to avoid creating duplicate tasks ` const taskCreateToolDescChinese = `使用此工具为当前编码会话创建结构化的任务列表。这有助于跟踪进度、组织复杂任务,并向用户展示工作的完整性。 它还帮助用户了解任务的进度和请求的整体进展。 ## 何时使用此工具 在以下场景中主动使用此工具: - 复杂的多步骤任务 - 当任务需要 3 个或更多不同的步骤或操作时 - 非简单的复杂任务 - 需要仔细规划或多个操作的任务 - 计划模式 - 使用计划模式时,创建任务列表来跟踪工作 - 用户明确要求待办列表 - 当用户直接要求使用待办列表时 - 用户提供多个任务 - 当用户提供待办事项列表时(编号或逗号分隔) - 收到新指令后 - 立即将用户需求记录为任务 - 开始处理任务时 - 在开始工作之前将其标记为 in_progress - 完成任务后 - 将其标记为已完成,并添加实施过程中发现的任何后续任务 ## 何时不使用此工具 在以下情况下跳过使用此工具: - 只有一个简单直接的任务 - 任务很简单,跟踪它没有组织上的好处 - 任务可以在少于 3 个简单步骤内完成 - 任务纯粹是对话性或信息性的 注意:如果只有一个简单任务要做,不应该使用此工具。在这种情况下,直接完成任务更好。 ## 任务字段 - **subject**:简短的、可操作的标题,使用祈使句形式(例如,"修复登录流程中的认证错误") - **description**:需要完成的工作的详细描述,包括上下文和验收标准 - **activeForm**:任务处于 in_progress 状态时在加载动画中显示的现在进行时形式(例如,"正在修复认证错误")。这会在你处理任务时显示给用户。 **重要**:创建任务时始终提供 activeForm。subject 应该是祈使句("运行测试"),而 activeForm 应该是现在进行时("正在运行测试")。所有任务创建时状态为 "pending"。 ## 提示 - 创建具有清晰、具体主题的任务,描述预期结果 - 在描述中包含足够的细节,以便其他代理能够理解并完成任务 - 创建任务后,如果需要,使用 TaskUpdate 设置依赖关系(blocks/blockedBy) - 先检查 TaskList 以避免创建重复任务 ` ================================================ FILE: adk/middlewares/plantask/task_create_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package plantask import ( "context" "path/filepath" "sync" "testing" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" ) func TestTaskCreateTool(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} tool := newTaskCreateTool(backend, baseDir, lock) info, err := tool.Info(ctx) assert.NoError(t, err) assert.Equal(t, TaskCreateToolName, info.Name) assert.Equal(t, taskCreateToolDesc, info.Desc) result, err := tool.InvokableRun(ctx, `{"subject": "Test Task", "description": "Test description", "activeForm": "Testing"}`) assert.NoError(t, err) assert.Equal(t, `{"result":"Task #1 created successfully: Test Task"}`, result) content, err := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) assert.NoError(t, err) var taskData task err = sonic.UnmarshalString(content.Content, &taskData) assert.NoError(t, err) assert.Equal(t, "1", taskData.ID) assert.Equal(t, "Test Task", taskData.Subject) assert.Equal(t, "Test description", taskData.Description) assert.Equal(t, taskStatusPending, taskData.Status) assert.Equal(t, "Testing", taskData.ActiveForm) hwContent, err := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, highWatermarkFileName)}) assert.NoError(t, err) assert.Equal(t, "1", hwContent.Content) result, err = tool.InvokableRun(ctx, `{"subject": "Second Task", "description": "Second description"}`) assert.NoError(t, err) assert.Equal(t, `{"result":"Task #2 created successfully: Second Task"}`, result) hwContent, err = backend.Read(ctx, 
&ReadRequest{FilePath: filepath.Join(baseDir, highWatermarkFileName)}) assert.NoError(t, err) assert.Equal(t, "2", hwContent.Content) } func TestTaskCreateToolWithMetadata(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} tool := newTaskCreateTool(backend, baseDir, lock) result, err := tool.InvokableRun(ctx, `{"subject": "Task with metadata", "description": "Has metadata", "metadata": {"key1": "value1", "key2": "value2"}}`) assert.NoError(t, err) assert.Contains(t, result, "Task #1 created successfully") content, err := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) assert.NoError(t, err) var taskData task err = sonic.UnmarshalString(content.Content, &taskData) assert.NoError(t, err) assert.Equal(t, "value1", taskData.Metadata["key1"]) assert.Equal(t, "value2", taskData.Metadata["key2"]) } ================================================ FILE: adk/middlewares/plantask/task_get.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package plantask import ( "context" "fmt" "path/filepath" "strings" "sync" "github.com/bytedance/sonic" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) func newTaskGetTool(backend Backend, baseDir string, lock *sync.Mutex) *taskGetTool { return &taskGetTool{Backend: backend, BaseDir: baseDir, lock: lock} } type taskGetTool struct { Backend Backend BaseDir string lock *sync.Mutex } func (t *taskGetTool) Info(ctx context.Context) (*schema.ToolInfo, error) { desc := internal.SelectPrompt(internal.I18nPrompts{ English: taskGetToolDesc, Chinese: taskGetToolDescChinese, }) return &schema.ToolInfo{ Name: TaskGetToolName, Desc: desc, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "taskId": { Type: schema.String, Desc: "The ID of the task to retrieve", Required: true, }, }), }, nil } type taskGetArgs struct { TaskID string `json:"taskId"` } func (t *taskGetTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { t.lock.Lock() defer t.lock.Unlock() params := &taskGetArgs{} err := sonic.UnmarshalString(argumentsInJSON, params) if err != nil { return "", err } if !isValidTaskID(params.TaskID) { return "", fmt.Errorf("%s validate task ID failed, err: invalid format: %s", TaskGetToolName, params.TaskID) } taskFileName := fmt.Sprintf("%s.json", params.TaskID) taskFilePath := filepath.Join(t.BaseDir, taskFileName) content, err := t.Backend.Read(ctx, &ReadRequest{ FilePath: taskFilePath, }) if err != nil { return "", fmt.Errorf("%s get Task #%s failed, err: %w", TaskGetToolName, params.TaskID, err) } taskData := &task{} err = sonic.UnmarshalString(content.Content, taskData) if err != nil { return "", fmt.Errorf("%s get Task #%s failed, err: %w", TaskGetToolName, params.TaskID, err) } var result strings.Builder result.WriteString(fmt.Sprintf("Task #%s: %s\n", taskData.ID, taskData.Subject)) result.WriteString(fmt.Sprintf("Status: 
%s\n", taskData.Status)) result.WriteString(fmt.Sprintf("Description: %s\n", taskData.Description)) if len(taskData.BlockedBy) > 0 { blockedByIDs := make([]string, len(taskData.BlockedBy)) for i, id := range taskData.BlockedBy { blockedByIDs[i] = "#" + id } result.WriteString(fmt.Sprintf("Blocked by: %s\n", strings.Join(blockedByIDs, ", "))) } if len(taskData.Blocks) > 0 { blocksIDs := make([]string, len(taskData.Blocks)) for i, id := range taskData.Blocks { blocksIDs[i] = "#" + id } result.WriteString(fmt.Sprintf("Blocks: %s\n", strings.Join(blocksIDs, ", "))) } if taskData.Owner != "" { result.WriteString(fmt.Sprintf("Owner: %s\n", taskData.Owner)) } resp := &taskOut{ Result: result.String(), } jsonResp, err := sonic.MarshalString(resp) if err != nil { return "", fmt.Errorf("%s marshal taskOut failed, err: %w", TaskGetToolName, err) } return jsonResp, nil } const TaskGetToolName = "TaskGet" const taskGetToolDesc = `Use this tool to retrieve a task by its ID from the task list. ## When to Use This Tool - When you need the full description and context before starting work on a task - To understand task dependencies (what it blocks, what blocks it) - After being assigned a task, to get complete requirements ## Output Returns full task details: - **subject**: Task title - **description**: Detailed requirements and context - **status**: 'pending', 'in_progress', or 'completed' - **blocks**: Tasks waiting on this one to complete - **blockedBy**: Tasks that must complete before this one can start ## Tips - After fetching a task, verify its blockedBy list is empty before beginning work. - Use TaskList to see all tasks in summary form. 
` const taskGetToolDescChinese = `使用此工具通过任务 ID 从任务列表中获取任务。 ## 何时使用此工具 - 当你需要在开始处理任务之前获取完整的描述和上下文时 - 了解任务依赖关系(它阻塞什么,什么阻塞它) - 被分配任务后,获取完整的需求 ## 输出 返回完整的任务详情: - **subject**:任务标题 - **description**:详细的需求和上下文 - **status**:'pending'、'in_progress' 或 'completed' - **blocks**:等待此任务完成的任务 - **blockedBy**:必须在此任务开始之前完成的任务 ## 提示 - 获取任务后,在开始工作之前验证其 blockedBy 列表是否为空。 - 使用 TaskList 查看所有任务的摘要形式。 ` ================================================ FILE: adk/middlewares/plantask/task_get_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package plantask import ( "context" "path/filepath" "sync" "testing" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" ) func TestTaskGetTool(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} taskData := &task{ ID: "1", Subject: "Test Task", Description: "Test description", Status: taskStatusPending, Blocks: []string{"2", "3"}, BlockedBy: []string{"4"}, } taskJSON, _ := sonic.MarshalString(taskData) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: taskJSON}) tool := newTaskGetTool(backend, baseDir, lock) info, err := tool.Info(ctx) assert.NoError(t, err) assert.Equal(t, TaskGetToolName, info.Name) assert.Equal(t, taskGetToolDesc, info.Desc) result, err := tool.InvokableRun(ctx, `{"taskId": "1"}`) assert.NoError(t, err) assert.Contains(t, result, "Task #1: Test Task") assert.Contains(t, result, "Status: "+taskStatusPending) assert.Contains(t, result, "Description: Test description") assert.Contains(t, result, "Blocked by: #4") assert.Contains(t, result, "Blocks: #2, #3") _, err = tool.InvokableRun(ctx, `{"taskId": "999"}`) assert.Error(t, err) } func TestTaskGetToolInvalidTaskID(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} tool := newTaskGetTool(backend, baseDir, lock) _, err := tool.InvokableRun(ctx, `{"taskId": "../../../etc/passwd"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "validate task ID failed") _, err = tool.InvokableRun(ctx, `{"taskId": "abc"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "validate task ID failed") } ================================================ FILE: adk/middlewares/plantask/task_list.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package plantask import ( "context" "fmt" "path/filepath" "sort" "strings" "sync" "github.com/bytedance/sonic" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) func newTaskListTool(backend Backend, baseDir string, lock *sync.Mutex) *taskListTool { return &taskListTool{Backend: backend, BaseDir: baseDir, lock: lock} } type taskListTool struct { Backend Backend BaseDir string lock *sync.Mutex } func (t *taskListTool) Info(ctx context.Context) (*schema.ToolInfo, error) { desc := internal.SelectPrompt(internal.I18nPrompts{ English: taskListToolDesc, Chinese: taskListToolDescChinese, }) return &schema.ToolInfo{ Name: TaskListToolName, Desc: desc, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{}), }, nil } func listTasks(ctx context.Context, backend Backend, baseDir string) ([]*task, error) { files, err := backend.LsInfo(ctx, &LsInfoRequest{ Path: baseDir, }) if err != nil { return nil, fmt.Errorf("%s list files in %s failed, err: %w", TaskListToolName, baseDir, err) } var tasks []*task for _, file := range files { fileName := filepath.Base(file.Path) if !strings.HasSuffix(fileName, ".json") { continue } taskID := strings.TrimSuffix(fileName, ".json") if !isValidTaskID(taskID) { continue } content, err := backend.Read(ctx, &ReadRequest{ FilePath: file.Path, }) if err != nil { return nil, fmt.Errorf("%s read task file %s failed, err: %w", TaskListToolName, file.Path, err) } taskData := &task{} err = sonic.UnmarshalString(content.Content, taskData) if err != nil { return 
nil, fmt.Errorf("%s parse task file %s failed, err: %w", TaskListToolName, file.Path, err) } tasks = append(tasks, taskData) } // sort tasks by ID sort.Slice(tasks, func(i, j int) bool { return tasks[i].ID < tasks[j].ID }) return tasks, nil } func (t *taskListTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { t.lock.Lock() defer t.lock.Unlock() tasks, err := listTasks(ctx, t.Backend, t.BaseDir) if err != nil { return "", err } if len(tasks) == 0 { resp := &taskOut{ Result: "No tasks found.", } jsonResp, marshalErr := sonic.MarshalString(resp) if marshalErr != nil { return "", fmt.Errorf("%s marshal taskOut failed, err: %w", TaskListToolName, marshalErr) } return jsonResp, nil } var result strings.Builder for i, taskData := range tasks { if i > 0 { result.WriteString("\n") } result.WriteString(fmt.Sprintf("#%s [%s] %s", taskData.ID, taskData.Status, taskData.Subject)) if taskData.Owner != "" { result.WriteString(fmt.Sprintf(" [owner: %s]", taskData.Owner)) } if len(taskData.BlockedBy) > 0 { blockedByIDs := make([]string, len(taskData.BlockedBy)) for j, id := range taskData.BlockedBy { blockedByIDs[j] = "#" + id } result.WriteString(fmt.Sprintf(" [blocked by %s]", strings.Join(blockedByIDs, ", "))) } } resp := &taskOut{ Result: result.String(), } jsonResp, err := sonic.MarshalString(resp) if err != nil { return "", fmt.Errorf("%s marshal taskOut failed, err: %w", TaskListToolName, err) } return jsonResp, nil } const TaskListToolName = "TaskList" const taskListToolDesc = `Use this tool to list all tasks in the task list. 
## When to Use This Tool - To see what tasks are available to work on (status: 'pending', no owner, not blocked) - To check overall progress on the project - To find tasks that are blocked and need dependencies resolved - After completing a task, to check for newly unblocked work or claim the next available task - **Prefer working on tasks in ID order** (lowest ID first) when multiple tasks are available, as earlier tasks often set up context for later ones ## Output Returns a summary of each task: - **id**: Task identifier (use with TaskGet, TaskUpdate) - **subject**: Brief description of the task - **status**: 'pending', 'in_progress', or 'completed' - **owner**: Agent ID if assigned, empty if available - **blockedBy**: List of open task IDs that must be resolved first (tasks with blockedBy cannot be claimed until dependencies resolve) Use TaskGet with a specific task ID to view full details including description and comments. ` const taskListToolDescChinese = `使用此工具列出任务列表中的所有任务。 ## 何时使用此工具 - 查看可以处理的任务(状态:'pending',无所有者,未被阻塞) - 检查项目的整体进度 - 查找被阻塞且需要解决依赖关系的任务 - 完成任务后,检查新解除阻塞的工作或认领下一个可用任务 - **优先按 ID 顺序处理任务**(最小 ID 优先),当有多个任务可用时,因为较早的任务通常为后续任务建立上下文 ## 输出 返回每个任务的摘要: - **id**:任务标识符(与 TaskGet、TaskUpdate 一起使用) - **subject**:任务的简要描述 - **status**:'pending'、'in_progress' 或 'completed' - **owner**:如果已分配则为代理 ID,如果可用则为空 - **blockedBy**:必须首先解决的开放任务 ID 列表(具有 blockedBy 的任务在依赖关系解决之前无法被认领) 使用 TaskGet 配合特定任务 ID 查看完整详情,包括描述和评论。 ` ================================================ FILE: adk/middlewares/plantask/task_list_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package plantask import ( "context" "path/filepath" "sync" "testing" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" ) func TestTaskListTool(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} tool := newTaskListTool(backend, baseDir, lock) info, err := tool.Info(ctx) assert.NoError(t, err) assert.Equal(t, TaskListToolName, info.Name) assert.Equal(t, taskListToolDesc, info.Desc) result, err := tool.InvokableRun(ctx, `{}`) assert.NoError(t, err) assert.Equal(t, `{"result":"No tasks found."}`, result) task1 := &task{ID: "1", Subject: "Task 1", Status: taskStatusPending, BlockedBy: []string{"2"}} task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) task2 := &task{ID: "2", Subject: "Task 2", Status: taskStatusInProgress, Owner: "agent1"} task2JSON, _ := sonic.MarshalString(task2) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "2.json"), Content: task2JSON}) result, err = tool.InvokableRun(ctx, `{}`) assert.NoError(t, err) assert.Contains(t, result, "#1 ["+taskStatusPending+"] Task 1") assert.Contains(t, result, "[blocked by #2]") assert.Contains(t, result, "#2 ["+taskStatusInProgress+"] Task 2") assert.Contains(t, result, "[owner: agent1]") } ================================================ FILE: adk/middlewares/plantask/task_update.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 
2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package plantask import ( "context" "fmt" "path/filepath" "strings" "sync" "github.com/bytedance/sonic" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) func newTaskUpdateTool(backend Backend, baseDir string, lock *sync.Mutex) *taskUpdateTool { return &taskUpdateTool{Backend: backend, BaseDir: baseDir, lock: lock} } type taskUpdateTool struct { Backend Backend BaseDir string lock *sync.Mutex } type taskUpdateArgs struct { TaskID string `json:"taskId"` Subject string `json:"subject,omitempty"` Description string `json:"description,omitempty"` ActiveForm string `json:"activeForm,omitempty"` Status string `json:"status,omitempty"` AddBlocks []string `json:"addBlocks,omitempty"` AddBlockedBy []string `json:"addBlockedBy,omitempty"` Owner string `json:"owner,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } func (t *taskUpdateTool) Info(ctx context.Context) (*schema.ToolInfo, error) { desc := internal.SelectPrompt(internal.I18nPrompts{ English: taskUpdateToolDesc, Chinese: taskUpdateToolDescChinese, }) return &schema.ToolInfo{ Name: TaskUpdateToolName, Desc: desc, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "taskId": { Type: schema.String, Desc: "The ID of the task to update", Required: true, }, "subject": { Type: schema.String, Desc: "New subject for the task", Required: false, }, "description": { Type: schema.String, Desc: "New description for the task", 
Required: false, }, "activeForm": { Type: schema.String, Desc: "Present continuous form shown in spinner when in_progress (e.g., \"Running tests\")", Required: false, }, "status": { Type: schema.String, Desc: "New status for the task: 'pending', 'in_progress', 'completed', or 'deleted'", Enum: []string{"pending", "in_progress", "completed", "deleted"}, Required: false, }, "addBlocks": { Type: schema.Array, Desc: "Task IDs that this task blocks", ElemInfo: &schema.ParameterInfo{Type: schema.String}, Required: false, }, "addBlockedBy": { Type: schema.Array, Desc: "Task IDs that block this task", ElemInfo: &schema.ParameterInfo{Type: schema.String}, Required: false, }, "owner": { Type: schema.String, Desc: "New owner for the task", Required: false, }, "metadata": { Type: schema.Object, Desc: "Metadata keys to merge into the task. Set a key to null to delete it.", SubParams: map[string]*schema.ParameterInfo{ "propertyNames": { Type: schema.String, }, }, Required: false, }, }), }, nil } func (t *taskUpdateTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { t.lock.Lock() defer t.lock.Unlock() params := &taskUpdateArgs{} err := sonic.UnmarshalString(argumentsInJSON, params) if err != nil { return "", err } if !isValidTaskID(params.TaskID) { return "", fmt.Errorf("%s validate task ID failed, err: invalid format: %s", TaskUpdateToolName, params.TaskID) } taskFileName := fmt.Sprintf("%s.json", params.TaskID) taskFilePath := filepath.Join(t.BaseDir, taskFileName) if params.Status == taskStatusDeleted { if removeErr := t.removeTaskFromDependencies(ctx, params.TaskID); removeErr != nil { return "", fmt.Errorf("%s remove Task #%s from dependencies failed, err: %w", TaskUpdateToolName, params.TaskID, removeErr) } err = t.Backend.Delete(ctx, &DeleteRequest{ FilePath: taskFilePath, }) if err != nil { return "", fmt.Errorf("%s delete Task #%s failed, err: %w", TaskUpdateToolName, params.TaskID, err) } resp := &taskOut{ Result: 
fmt.Sprintf("Updated task #%s deleted", params.TaskID), } jsonResp, marshalErr := sonic.MarshalString(resp) if marshalErr != nil { return "", fmt.Errorf("%s marshal taskOut failed, err: %w", TaskUpdateToolName, marshalErr) } return jsonResp, nil } content, err := t.Backend.Read(ctx, &ReadRequest{ FilePath: taskFilePath, }) if err != nil { return "", fmt.Errorf("%s read Task #%s failed, err: %w", TaskUpdateToolName, params.TaskID, err) } taskData := &task{} err = sonic.UnmarshalString(content.Content, taskData) if err != nil { return "", fmt.Errorf("%s parse Task #%s failed, err: %w", TaskUpdateToolName, params.TaskID, err) } var updatedFields []string if params.Subject != "" { taskData.Subject = params.Subject updatedFields = append(updatedFields, "subject") } if params.Description != "" { taskData.Description = params.Description updatedFields = append(updatedFields, "description") } if params.ActiveForm != "" { taskData.ActiveForm = params.ActiveForm updatedFields = append(updatedFields, "activeForm") } if params.Status != "" { taskData.Status = params.Status updatedFields = append(updatedFields, "status") } if len(params.AddBlocks) > 0 || len(params.AddBlockedBy) > 0 { tasks, listErr := listTasks(ctx, t.Backend, t.BaseDir) if listErr != nil { return "", fmt.Errorf("%s list tasks failed, err: %w", TaskUpdateToolName, listErr) } taskMap := make(map[string]*task, len(tasks)) for _, tk := range tasks { taskMap[tk.ID] = tk } if len(params.AddBlocks) > 0 { for _, blockedTaskID := range params.AddBlocks { if !isValidTaskID(blockedTaskID) { return "", fmt.Errorf("%s validate blocked task ID failed, err: invalid format: %s", TaskUpdateToolName, blockedTaskID) } if hasCyclicDependency(taskMap, params.TaskID, blockedTaskID) { return "", fmt.Errorf("%s adding Task #%s to blocks of Task #%s would create a cyclic dependency", TaskUpdateToolName, blockedTaskID, params.TaskID) } } for _, blockedTaskID := range params.AddBlocks { if addErr := t.addBlockedByToTask(ctx, 
blockedTaskID, params.TaskID); addErr != nil { return "", fmt.Errorf("%s update Task #%s blocks failed, err: %w", TaskUpdateToolName, params.TaskID, addErr) } } taskData.Blocks = appendUnique(taskData.Blocks, params.AddBlocks...) updatedFields = append(updatedFields, "blocks") } if len(params.AddBlockedBy) > 0 { for _, blockingTaskID := range params.AddBlockedBy { if !isValidTaskID(blockingTaskID) { return "", fmt.Errorf("%s validate blocking task ID failed, err: invalid format: %s", TaskUpdateToolName, blockingTaskID) } if hasCyclicDependency(taskMap, blockingTaskID, params.TaskID) { return "", fmt.Errorf("%s adding Task #%s to blockedBy of Task #%s would create a cyclic dependency", TaskUpdateToolName, blockingTaskID, params.TaskID) } } for _, blockingTaskID := range params.AddBlockedBy { if addErr := t.addBlocksToTask(ctx, blockingTaskID, params.TaskID); addErr != nil { return "", fmt.Errorf("%s update Task #%s blockedBy failed, err: %w", TaskUpdateToolName, params.TaskID, addErr) } } taskData.BlockedBy = appendUnique(taskData.BlockedBy, params.AddBlockedBy...) 
updatedFields = append(updatedFields, "blockedBy") } } if params.Owner != "" { taskData.Owner = params.Owner updatedFields = append(updatedFields, "owner") } if params.Metadata != nil { if taskData.Metadata == nil { taskData.Metadata = make(map[string]any) } for k, v := range params.Metadata { if v == nil { delete(taskData.Metadata, k) } else { taskData.Metadata[k] = v } } updatedFields = append(updatedFields, "metadata") } updatedContent, err := sonic.MarshalString(taskData) if err != nil { return "", fmt.Errorf("%s marshal Task #%s failed, err: %w", TaskUpdateToolName, params.TaskID, err) } err = t.Backend.Write(ctx, &WriteRequest{ FilePath: taskFilePath, Content: updatedContent, }) if err != nil { return "", fmt.Errorf("%s write Task #%s failed, err: %w", TaskUpdateToolName, params.TaskID, err) } if params.Status == taskStatusCompleted { if checkErr := t.checkIfNeedDeleteAllTasks(ctx); checkErr != nil { return "", fmt.Errorf("%s check and delete all tasks failed, err: %w", TaskUpdateToolName, checkErr) } } resp := &taskOut{ Result: fmt.Sprintf("Updated task #%s %s", params.TaskID, strings.Join(updatedFields, ", ")), } jsonResp, err := sonic.MarshalString(resp) if err != nil { return "", fmt.Errorf("%s marshal taskOut failed, err: %w", TaskUpdateToolName, err) } return jsonResp, nil } func (t *taskUpdateTool) removeTaskFromDependencies(ctx context.Context, deletedTaskID string) error { tasks, err := listTasks(ctx, t.Backend, t.BaseDir) if err != nil { return err } for _, taskData := range tasks { if taskData.ID == deletedTaskID { continue } modified := false newBlocks := make([]string, 0, len(taskData.Blocks)) for _, id := range taskData.Blocks { if id != deletedTaskID { newBlocks = append(newBlocks, id) } else { modified = true } } newBlockedBy := make([]string, 0, len(taskData.BlockedBy)) for _, id := range taskData.BlockedBy { if id != deletedTaskID { newBlockedBy = append(newBlockedBy, id) } else { modified = true } } if modified { taskData.Blocks = newBlocks 
taskData.BlockedBy = newBlockedBy updatedContent, err := sonic.MarshalString(taskData) if err != nil { return fmt.Errorf("failed to marshal task #%s: %w", taskData.ID, err) } taskFilePath := filepath.Join(t.BaseDir, fmt.Sprintf("%s.json", taskData.ID)) if err := t.Backend.Write(ctx, &WriteRequest{FilePath: taskFilePath, Content: updatedContent}); err != nil { return fmt.Errorf("failed to write task #%s: %w", taskData.ID, err) } } } return nil } func (t *taskUpdateTool) addBlockedByToTask(ctx context.Context, targetTaskID, blockerTaskID string) error { taskFilePath := filepath.Join(t.BaseDir, fmt.Sprintf("%s.json", targetTaskID)) content, err := t.Backend.Read(ctx, &ReadRequest{FilePath: taskFilePath}) if err != nil { return fmt.Errorf("failed to read task #%s for updating blockedBy: %w", targetTaskID, err) } targetTask := &task{} if unmarshalErr := sonic.UnmarshalString(content.Content, targetTask); unmarshalErr != nil { return fmt.Errorf("failed to parse task #%s: %w", targetTaskID, unmarshalErr) } targetTask.BlockedBy = appendUnique(targetTask.BlockedBy, blockerTaskID) updatedContent, err := sonic.MarshalString(targetTask) if err != nil { return fmt.Errorf("failed to marshal task #%s: %w", targetTaskID, err) } if err := t.Backend.Write(ctx, &WriteRequest{FilePath: taskFilePath, Content: updatedContent}); err != nil { return fmt.Errorf("failed to write task #%s: %w", targetTaskID, err) } return nil } func (t *taskUpdateTool) addBlocksToTask(ctx context.Context, targetTaskID, blockedTaskID string) error { taskFilePath := filepath.Join(t.BaseDir, fmt.Sprintf("%s.json", targetTaskID)) content, err := t.Backend.Read(ctx, &ReadRequest{FilePath: taskFilePath}) if err != nil { return fmt.Errorf("failed to read task #%s for updating blocks: %w", targetTaskID, err) } targetTask := &task{} if unmarshalErr := sonic.UnmarshalString(content.Content, targetTask); unmarshalErr != nil { return fmt.Errorf("failed to parse task #%s: %w", targetTaskID, unmarshalErr) } 
targetTask.Blocks = appendUnique(targetTask.Blocks, blockedTaskID) updatedContent, err := sonic.MarshalString(targetTask) if err != nil { return fmt.Errorf("failed to marshal task #%s: %w", targetTaskID, err) } if err := t.Backend.Write(ctx, &WriteRequest{FilePath: taskFilePath, Content: updatedContent}); err != nil { return fmt.Errorf("failed to write task #%s: %w", targetTaskID, err) } return nil } // checkIfNeedDeleteAllTasks checks if all tasks are completed, if so, it deletes all tasks func (t *taskUpdateTool) checkIfNeedDeleteAllTasks(ctx context.Context) error { tasks, err := listTasks(ctx, t.Backend, t.BaseDir) if err != nil { return err } for _, task := range tasks { if task.Status != taskStatusCompleted { return nil } } for _, task := range tasks { err := t.Backend.Delete(ctx, &DeleteRequest{ FilePath: filepath.Join(t.BaseDir, task.ID+".json"), }) if err != nil { return err } } return nil } const TaskUpdateToolName = "TaskUpdate" const taskUpdateToolDesc = `Use this tool to update a task in the task list. 
## When to Use This Tool **Mark tasks as resolved:** - When you have completed the work described in a task - When a task is no longer needed or has been superseded - IMPORTANT: Always mark your assigned tasks as resolved when you finish them - After resolving, call TaskList to find your next task - ONLY mark a task as completed when you have FULLY accomplished it - If you encounter errors, blockers, or cannot finish, keep the task as in_progress - When blocked, create a new task describing what needs to be resolved - Never mark a task as completed if: - Tests are failing - Implementation is partial - You encountered unresolved errors - You couldn't find necessary files or dependencies **Delete tasks:** - When a task is no longer relevant or was created in error - Setting status to ` + "`deleted`" + ` permanently removes the task **Update task details:** - When requirements change or become clearer - When establishing dependencies between tasks ## Fields You Can Update - **status**: The task status (see Status Workflow below) - **subject**: Change the task title (imperative form, e.g., "Run tests") - **description**: Change the task description - **activeForm**: Present continuous form shown in spinner when in_progress (e.g., "Running tests") - **owner**: Change the task owner (agent name) - **metadata**: Merge metadata keys into the task (set a key to null to delete it) - **addBlocks**: Mark tasks that cannot start until this one completes - **addBlockedBy**: Mark tasks that must complete before this one can start ## Status Workflow Status progresses: ` + "`pending`" + ` → ` + "`in_progress`" + ` → ` + "`completed`" + ` Use ` + "`deleted`" + ` to permanently remove a task. ## Staleness Make sure to read a task's latest state using ` + "`TaskGet`" + ` before updating it. 
## Examples Mark task as in progress when starting work: ` + "```json" + ` {"taskId": "1", "status": "in_progress"} ` + "```" + ` Mark task as completed after finishing work: ` + "```json" + ` {"taskId": "1", "status": "completed"} ` + "```" + ` Delete a task: ` + "```json" + ` {"taskId": "1", "status": "deleted"} ` + "```" + ` Claim a task by setting owner: ` + "```json" + ` {"taskId": "1", "owner": "my-name"} ` + "```" + ` Set up task dependencies: ` + "```json" + ` {"taskId": "2", "addBlockedBy": ["1"]} ` + "```" + ` ` const taskUpdateToolDescChinese = `使用此工具更新任务列表中的任务。 ## 何时使用此工具 **将任务标记为已完成:** - 当你完成了任务中描述的工作时 - 当任务不再需要或已被取代时 - 重要:完成分配给你的任务后,务必将其标记为已完成 - 完成后,调用 TaskList 查找下一个任务 - 只有在完全完成任务时才将其标记为已完成 - 如果遇到错误、阻塞或无法完成,请保持任务为 in_progress 状态 - 当被阻塞时,创建一个新任务描述需要解决的问题 - 在以下情况下不要将任务标记为已完成: - 测试失败 - 实现不完整 - 遇到未解决的错误 - 找不到必要的文件或依赖项 **删除任务:** - 当任务不再相关或创建错误时 - 将状态设置为 ` + "`deleted`" + ` 会永久删除任务 **更新任务详情:** - 当需求变更或变得更清晰时 - 当建立任务之间的依赖关系时 ## 可更新的字段 - **status**:任务状态(参见下方状态流程) - **subject**:更改任务标题(使用祈使句形式,例如"运行测试") - **description**:更改任务描述 - **activeForm**:in_progress 状态时在加载动画中显示的现在进行时形式(例如"正在运行测试") - **owner**:更改任务所有者(代理名称) - **metadata**:将元数据键合并到任务中(将键设置为 null 可删除它) - **addBlocks**:标记在此任务完成之前无法开始的任务 - **addBlockedBy**:标记必须在此任务开始之前完成的任务 ## 状态流程 状态进展:` + "`pending`" + ` → ` + "`in_progress`" + ` → ` + "`completed`" + ` 使用 ` + "`deleted`" + ` 永久删除任务。 ## 过期性 更新任务前,请确保使用 ` + "`TaskGet`" + ` 读取任务的最新状态。 ## 示例 开始工作时将任务标记为进行中: ` + "```json" + ` {"taskId": "1", "status": "in_progress"} ` + "```" + ` 完成工作后将任务标记为已完成: ` + "```json" + ` {"taskId": "1", "status": "completed"} ` + "```" + ` 删除任务: ` + "```json" + ` {"taskId": "1", "status": "deleted"} ` + "```" + ` 通过设置 owner 认领任务: ` + "```json" + ` {"taskId": "1", "owner": "my-name"} ` + "```" + ` 设置任务依赖关系: ` + "```json" + ` {"taskId": "2", "addBlockedBy": ["1"]} ` + "```" + ` ` ================================================ FILE: adk/middlewares/plantask/task_update_test.go ================================================ /* * 
Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package plantask import ( "context" "path/filepath" "sync" "testing" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" ) func TestTaskUpdateTool(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} taskData := &task{ ID: "1", Subject: "Original Subject", Description: "Original description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } taskJSON, _ := sonic.MarshalString(taskData) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: taskJSON}) tool := newTaskUpdateTool(backend, baseDir, lock) info, err := tool.Info(ctx) assert.NoError(t, err) assert.Equal(t, TaskUpdateToolName, info.Name) assert.Equal(t, taskUpdateToolDesc, info.Desc) result, err := tool.InvokableRun(ctx, `{"taskId": "1", "status": "in_progress"}`) assert.NoError(t, err) assert.Contains(t, result, "Updated task #1") assert.Contains(t, result, "status") content, err := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) assert.NoError(t, err) var updated task _ = sonic.UnmarshalString(content.Content, &updated) assert.Equal(t, taskStatusInProgress, updated.Status) result, err = tool.InvokableRun(ctx, `{"taskId": "1", "subject": "New Subject", "description": "New description"}`) assert.NoError(t, err) assert.Contains(t, result, "subject") assert.Contains(t, result, 
"description") content, _ = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) _ = sonic.UnmarshalString(content.Content, &updated) assert.Equal(t, "New Subject", updated.Subject) assert.Equal(t, "New description", updated.Description) } func TestTaskUpdateToolOwnerAndMetadata(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} taskData := &task{ ID: "1", Subject: "Test Task", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } taskJSON, _ := sonic.MarshalString(taskData) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: taskJSON}) tool := newTaskUpdateTool(backend, baseDir, lock) result, err := tool.InvokableRun(ctx, `{"taskId": "1", "owner": "agent1"}`) assert.NoError(t, err) assert.Contains(t, result, "owner") content, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) var updated task _ = sonic.UnmarshalString(content.Content, &updated) assert.Equal(t, "agent1", updated.Owner) result, err = tool.InvokableRun(ctx, `{"taskId": "1", "metadata": {"key1": "value1", "key2": "value2"}}`) assert.NoError(t, err) assert.Contains(t, result, "metadata") content, _ = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) _ = sonic.UnmarshalString(content.Content, &updated) assert.Equal(t, "value1", updated.Metadata["key1"]) assert.Equal(t, "value2", updated.Metadata["key2"]) result, err = tool.InvokableRun(ctx, `{"taskId": "1", "metadata": {"key1": null, "key3": "value3"}}`) assert.NoError(t, err) content, _ = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) var updated2 task _ = sonic.UnmarshalString(content.Content, &updated2) _, key1Exists := updated2.Metadata["key1"] assert.False(t, key1Exists) assert.Equal(t, "value2", updated2.Metadata["key2"]) assert.Equal(t, "value3", updated2.Metadata["key3"]) } func 
TestTaskUpdateToolBlocks(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Test Task", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) task2 := &task{ ID: "2", Subject: "Task 2", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task2JSON, _ := sonic.MarshalString(task2) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "2.json"), Content: task2JSON}) task3 := &task{ ID: "3", Subject: "Task 3", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task3JSON, _ := sonic.MarshalString(task3) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "3.json"), Content: task3JSON}) task4 := &task{ ID: "4", Subject: "Task 4", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task4JSON, _ := sonic.MarshalString(task4) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "4.json"), Content: task4JSON}) tool := newTaskUpdateTool(backend, baseDir, lock) result, err := tool.InvokableRun(ctx, `{"taskId": "1", "addBlocks": ["2", "3"]}`) assert.NoError(t, err) assert.Contains(t, result, "blocks") content, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) var updated task _ = sonic.UnmarshalString(content.Content, &updated) assert.Equal(t, []string{"2", "3"}, updated.Blocks) result, err = tool.InvokableRun(ctx, `{"taskId": "1", "addBlockedBy": ["4"]}`) assert.NoError(t, err) assert.Contains(t, result, "blockedBy") content, _ = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) _ = 
sonic.UnmarshalString(content.Content, &updated) assert.Equal(t, []string{"4"}, updated.BlockedBy) } func TestTaskUpdateToolDelete(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} taskData := &task{ ID: "1", Subject: "Test Task", Description: "Test description", Status: taskStatusPending, } taskJSON, _ := sonic.MarshalString(taskData) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: taskJSON}) tool := newTaskUpdateTool(backend, baseDir, lock) result, err := tool.InvokableRun(ctx, `{"taskId": "1", "status": "deleted"}`) assert.NoError(t, err) assert.Contains(t, result, "deleted") _, err = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) assert.Error(t, err) } func TestTaskUpdateToolInvalidTaskID(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} tool := newTaskUpdateTool(backend, baseDir, lock) _, err := tool.InvokableRun(ctx, `{"taskId": "../../../etc/passwd", "status": "in_progress"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "validate task ID failed") _, err = tool.InvokableRun(ctx, `{"taskId": "abc", "status": "in_progress"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "validate task ID failed") _, err = tool.InvokableRun(ctx, `{"taskId": "1.5", "status": "in_progress"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "validate task ID failed") task1 := &task{ ID: "1", Subject: "Task 1", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) _, err = tool.InvokableRun(ctx, `{"taskId": "1", "addBlocks": ["invalid"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "validate blocked task ID failed") _, err = tool.InvokableRun(ctx, 
`{"taskId": "1", "addBlockedBy": ["invalid"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "validate blocking task ID failed") } func TestTaskUpdateToolBlocksDeduplication(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Task 1", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) task2 := &task{ ID: "2", Subject: "Task 2", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{"1"}, } task2JSON, _ := sonic.MarshalString(task2) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "2.json"), Content: task2JSON}) task3 := &task{ ID: "3", Subject: "Task 3", Description: "Test description", Status: taskStatusPending, Blocks: []string{"1"}, BlockedBy: []string{}, } task3JSON, _ := sonic.MarshalString(task3) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "3.json"), Content: task3JSON}) task4 := &task{ ID: "4", Subject: "Task 4", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task4JSON, _ := sonic.MarshalString(task4) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "4.json"), Content: task4JSON}) task5 := &task{ ID: "5", Subject: "Task 5", Description: "Test description", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task5JSON, _ := sonic.MarshalString(task5) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "5.json"), Content: task5JSON}) tool := newTaskUpdateTool(backend, baseDir, lock) _, err := tool.InvokableRun(ctx, `{"taskId": "1", "addBlocks": ["2", "4", "4"]}`) assert.NoError(t, err) content, _ := backend.Read(ctx, &ReadRequest{FilePath: 
filepath.Join(baseDir, "1.json")}) var updated task _ = sonic.UnmarshalString(content.Content, &updated) assert.Equal(t, []string{"2", "4"}, updated.Blocks) _, err = tool.InvokableRun(ctx, `{"taskId": "1", "addBlockedBy": ["3", "5", "5"]}`) assert.NoError(t, err) content, _ = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) _ = sonic.UnmarshalString(content.Content, &updated) assert.Equal(t, []string{"3", "5"}, updated.BlockedBy) } func TestTaskUpdateToolBidirectionalBlocks(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Task 1", Description: "First task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) task2 := &task{ ID: "2", Subject: "Task 2", Description: "Second task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task2JSON, _ := sonic.MarshalString(task2) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "2.json"), Content: task2JSON}) task3 := &task{ ID: "3", Subject: "Task 3", Description: "Third task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task3JSON, _ := sonic.MarshalString(task3) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "3.json"), Content: task3JSON}) tool := newTaskUpdateTool(backend, baseDir, lock) _, err := tool.InvokableRun(ctx, `{"taskId": "1", "addBlocks": ["2", "3"]}`) assert.NoError(t, err) content1, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) var updatedTask1 task _ = sonic.UnmarshalString(content1.Content, &updatedTask1) assert.Equal(t, []string{"2", "3"}, updatedTask1.Blocks) assert.Empty(t, updatedTask1.BlockedBy) content2, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, 
"2.json")}) var updatedTask2 task _ = sonic.UnmarshalString(content2.Content, &updatedTask2) assert.Empty(t, updatedTask2.Blocks) assert.Equal(t, []string{"1"}, updatedTask2.BlockedBy) content3, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "3.json")}) var updatedTask3 task _ = sonic.UnmarshalString(content3.Content, &updatedTask3) assert.Empty(t, updatedTask3.Blocks) assert.Equal(t, []string{"1"}, updatedTask3.BlockedBy) } func TestTaskUpdateToolBidirectionalBlockedBy(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Task 1", Description: "First task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) task2 := &task{ ID: "2", Subject: "Task 2", Description: "Second task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task2JSON, _ := sonic.MarshalString(task2) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "2.json"), Content: task2JSON}) task3 := &task{ ID: "3", Subject: "Task 3", Description: "Third task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task3JSON, _ := sonic.MarshalString(task3) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "3.json"), Content: task3JSON}) tool := newTaskUpdateTool(backend, baseDir, lock) _, err := tool.InvokableRun(ctx, `{"taskId": "3", "addBlockedBy": ["1", "2"]}`) assert.NoError(t, err) content3, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "3.json")}) var updatedTask3 task _ = sonic.UnmarshalString(content3.Content, &updatedTask3) assert.Empty(t, updatedTask3.Blocks) assert.Equal(t, []string{"1", "2"}, updatedTask3.BlockedBy) content1, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) 
var updatedTask1 task _ = sonic.UnmarshalString(content1.Content, &updatedTask1) assert.Equal(t, []string{"3"}, updatedTask1.Blocks) assert.Empty(t, updatedTask1.BlockedBy) content2, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "2.json")}) var updatedTask2 task _ = sonic.UnmarshalString(content2.Content, &updatedTask2) assert.Equal(t, []string{"3"}, updatedTask2.Blocks) assert.Empty(t, updatedTask2.BlockedBy) } func TestTaskUpdateToolBidirectionalWithNonExistentTask(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Task 1", Description: "First task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) tool := newTaskUpdateTool(backend, baseDir, lock) _, err := tool.InvokableRun(ctx, `{"taskId": "1", "addBlocks": ["999"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "update Task #1 blocks failed") _, err = tool.InvokableRun(ctx, `{"taskId": "1", "addBlockedBy": ["999"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "update Task #1 blockedBy failed") } func TestTaskUpdateToolCyclicDependencyDetection(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Task 1", Description: "First task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) task2 := &task{ ID: "2", Subject: "Task 2", Description: "Second task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task2JSON, _ := sonic.MarshalString(task2) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, 
"2.json"), Content: task2JSON}) task3 := &task{ ID: "3", Subject: "Task 3", Description: "Third task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task3JSON, _ := sonic.MarshalString(task3) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "3.json"), Content: task3JSON}) tool := newTaskUpdateTool(backend, baseDir, lock) _, err := tool.InvokableRun(ctx, `{"taskId": "1", "addBlocks": ["1"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "cyclic dependency") _, err = tool.InvokableRun(ctx, `{"taskId": "1", "addBlockedBy": ["1"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "cyclic dependency") _, err = tool.InvokableRun(ctx, `{"taskId": "1", "addBlocks": ["2"]}`) assert.NoError(t, err) _, err = tool.InvokableRun(ctx, `{"taskId": "2", "addBlocks": ["1"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "cyclic dependency") _, err = tool.InvokableRun(ctx, `{"taskId": "1", "addBlockedBy": ["2"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "cyclic dependency") _, err = tool.InvokableRun(ctx, `{"taskId": "2", "addBlocks": ["3"]}`) assert.NoError(t, err) _, err = tool.InvokableRun(ctx, `{"taskId": "3", "addBlocks": ["1"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "cyclic dependency") _, err = tool.InvokableRun(ctx, `{"taskId": "1", "addBlockedBy": ["3"]}`) assert.Error(t, err) assert.Contains(t, err.Error(), "cyclic dependency") content1, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) var updatedTask1 task _ = sonic.UnmarshalString(content1.Content, &updatedTask1) assert.Equal(t, []string{"2"}, updatedTask1.Blocks) assert.Empty(t, updatedTask1.BlockedBy) content2, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "2.json")}) var updatedTask2 task _ = sonic.UnmarshalString(content2.Content, &updatedTask2) assert.Equal(t, []string{"3"}, updatedTask2.Blocks) assert.Equal(t, []string{"1"}, updatedTask2.BlockedBy) content3, _ := 
backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "3.json")}) var updatedTask3 task _ = sonic.UnmarshalString(content3.Content, &updatedTask3) assert.Empty(t, updatedTask3.Blocks) assert.Equal(t, []string{"2"}, updatedTask3.BlockedBy) } func TestTaskUpdateToolDeleteCleansDependencies(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Task 1", Description: "First task", Status: taskStatusPending, Blocks: []string{"2", "3"}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) task2 := &task{ ID: "2", Subject: "Task 2", Description: "Second task", Status: taskStatusPending, Blocks: []string{"3"}, BlockedBy: []string{"1"}, } task2JSON, _ := sonic.MarshalString(task2) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "2.json"), Content: task2JSON}) task3 := &task{ ID: "3", Subject: "Task 3", Description: "Third task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{"1", "2"}, } task3JSON, _ := sonic.MarshalString(task3) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "3.json"), Content: task3JSON}) tool := newTaskUpdateTool(backend, baseDir, lock) result, err := tool.InvokableRun(ctx, `{"taskId": "1", "status": "deleted"}`) assert.NoError(t, err) assert.Contains(t, result, "deleted") _, err = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) assert.Error(t, err) content2, err := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "2.json")}) assert.NoError(t, err) var updatedTask2 task _ = sonic.UnmarshalString(content2.Content, &updatedTask2) assert.Equal(t, []string{"3"}, updatedTask2.Blocks) assert.Empty(t, updatedTask2.BlockedBy) content3, err := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "3.json")}) 
assert.NoError(t, err) var updatedTask3 task _ = sonic.UnmarshalString(content3.Content, &updatedTask3) assert.Empty(t, updatedTask3.Blocks) assert.Equal(t, []string{"2"}, updatedTask3.BlockedBy) } func TestTaskUpdateToolAutoDeleteAllTasksWhenAllCompleted(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Task 1", Description: "First task", Status: taskStatusCompleted, Blocks: []string{}, BlockedBy: []string{}, } task1JSON, _ := sonic.MarshalString(task1) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON}) task2 := &task{ ID: "2", Subject: "Task 2", Description: "Second task", Status: taskStatusCompleted, Blocks: []string{}, BlockedBy: []string{}, } task2JSON, _ := sonic.MarshalString(task2) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "2.json"), Content: task2JSON}) task3 := &task{ ID: "3", Subject: "Task 3", Description: "Third task", Status: taskStatusPending, Blocks: []string{}, BlockedBy: []string{}, } task3JSON, _ := sonic.MarshalString(task3) _ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "3.json"), Content: task3JSON}) tool := newTaskUpdateTool(backend, baseDir, lock) _, err := tool.InvokableRun(ctx, `{"taskId": "3", "status": "completed"}`) assert.NoError(t, err) _, err = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")}) assert.Error(t, err) _, err = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "2.json")}) assert.Error(t, err) _, err = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "3.json")}) assert.Error(t, err) } func TestTaskUpdateToolNoDeleteWhenNotAllCompleted(t *testing.T) { ctx := context.Background() backend := newInMemoryBackend() baseDir := "/tmp/tasks" lock := &sync.Mutex{} task1 := &task{ ID: "1", Subject: "Task 1", Description: "First task", Status: taskStatusPending, Blocks: 
[]string{},
		BlockedBy:   []string{},
	}
	task1JSON, _ := sonic.MarshalString(task1)
	_ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "1.json"), Content: task1JSON})
	task2 := &task{
		ID:          "2",
		Subject:     "Task 2",
		Description: "Second task",
		Status:      taskStatusPending,
		Blocks:      []string{},
		BlockedBy:   []string{},
	}
	task2JSON, _ := sonic.MarshalString(task2)
	_ = backend.Write(ctx, &WriteRequest{FilePath: filepath.Join(baseDir, "2.json"), Content: task2JSON})
	tool := newTaskUpdateTool(backend, baseDir, lock)

	// Only task 1 is completed; task 2 stays pending, so nothing may be deleted.
	_, err := tool.InvokableRun(ctx, `{"taskId": "1", "status": "completed"}`)
	assert.NoError(t, err)
	_, err = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")})
	assert.NoError(t, err)
	_, err = backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "2.json")})
	assert.NoError(t, err)
	content1, _ := backend.Read(ctx, &ReadRequest{FilePath: filepath.Join(baseDir, "1.json")})
	var updatedTask1 task
	_ = sonic.UnmarshalString(content1.Content, &updatedTask1)
	assert.Equal(t, taskStatusCompleted, updatedTask1.Status)
}
================================================ FILE: adk/middlewares/reduction/consts.go ================================================
/* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */

// Package reduction provides middlewares to trim context and clear tool results.
package reduction

import "github.com/cloudwego/eino/adk/internal"

// Notice templates (English / Chinese) describing a truncated tool result.
// Names in braces ({original_size}, {file_path}, {preview_size},
// {preview_first}, {preview_last}) are placeholders substituted elsewhere.
// Whitespace and line breaks inside these raw strings are part of the text.
const (
	truncFmt = ` Output too large ({original_size}).
Full output saved to: {file_path} Preview (first {preview_size}): {preview_first} Preview (last {preview_size}): {preview_last} `
	truncFmtZh = ` 输出结果过大 ({original_size}). 完整输出保存到: {file_path} 预览 (前 {preview_size}): {preview_first} 预览 (后 {preview_size}): {preview_last} `
)

// Replacement texts for a cleared tool result: the "WithOffloading" variants
// point at a saved file ({file_path}) and a read tool ({read_tool_name}); the
// "WithoutOffloading" variants are fixed placeholders.
const (
	clearWithOffloadingFmt      = `Tool result saved to: {file_path} Use {read_tool_name} to view`
	clearWithOffloadingFmtZh    = `工具结果已保存至: {file_path} 使用 {read_tool_name} 进行查看`
	clearWithoutOffloadingFmt   = `[Old tool result content cleared]`
	clearWithoutOffloadingFmtZh = `[工具输出结果已清理]`
)

// Keys whose names suggest they mark a message as already handled by this
// middleware and record its token count — NOTE(review): confirm against usage.
const (
	msgReducedFlag   = "_reduction_mw_processed"
	msgReducedTokens = "_reduction_mw_tokens"
)

// getTruncFmt returns the truncation notice template in the language chosen by
// internal.SelectPrompt.
func getTruncFmt() string {
	return internal.SelectPrompt(internal.I18nPrompts{
		English: truncFmt,
		Chinese: truncFmtZh,
	})
}

// getClearWithOffloadingFmt returns the "saved to file" placeholder template in
// the language chosen by internal.SelectPrompt.
func getClearWithOffloadingFmt() string {
	return internal.SelectPrompt(internal.I18nPrompts{
		English: clearWithOffloadingFmt,
		Chinese: clearWithOffloadingFmtZh,
	})
}

// getClearWithoutOffloadingFmt returns the plain "content cleared" placeholder
// in the language chosen by internal.SelectPrompt.
func getClearWithoutOffloadingFmt() string {
	return internal.SelectPrompt(internal.I18nPrompts{
		English: clearWithoutOffloadingFmt,
		Chinese: clearWithoutOffloadingFmtZh,
	})
}

// scene distinguishes the two reduction paths: truncating a tool result versus
// clearing it.
type scene int

const (
	sceneTruncation scene = 1
	sceneClear      scene = 2
)
================================================ FILE: adk/middlewares/reduction/internal/clear_tool_result.go ================================================
/* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/

// Package internal provides middlewares to trim context and clear tool results.
package internal

import (
	"context"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/schema"
)

// ClearToolResultConfig configures the tool result clearing middleware.
// This middleware clears old tool results when their total token count exceeds a threshold,
// while protecting recent messages within a token budget.
type ClearToolResultConfig struct {
	// ToolResultTokenThreshold is the threshold for total tool result tokens.
	// When the sum of tokens of all tool results that do not already equal the
	// placeholder exceeds this threshold, old tool results (outside the
	// KeepRecentTokens range) are replaced with the placeholder. Results of
	// ExcludeTools still count toward this sum even though they are never cleared.
	// The default estimate is (byte length of Content plus tool-call arguments) / 4,
	// rounded up; see TokenCounter.
	// If 0, defaults to 20000.
	ToolResultTokenThreshold int

	// KeepRecentTokens is the token budget for recent messages to keep intact.
	// Messages are accumulated from the end while they fit in this budget; the
	// first message that would overflow it is also kept, and everything older
	// is eligible for clearing.
	// If 0, defaults to 40000.
	KeepRecentTokens int

	// ClearToolResultPlaceholder is the text to replace old tool results with.
	// If empty, defaults to "[Old tool result content cleared]".
	ClearToolResultPlaceholder string

	// TokenCounter is a custom function to estimate token count for a message.
	// If nil, uses the default counter (UTF-8 byte length of Content plus
	// tool-call arguments, divided by 4 and rounded up).
	TokenCounter func(msg *schema.Message) int

	// ExcludeTools is a list of tool names (matched against Message.ToolName)
	// whose results should never be cleared.
	ExcludeTools []string
}

// NewClearToolResult creates a new middleware that clears old tool results
// based on token thresholds while protecting recent messages.
func NewClearToolResult(ctx context.Context, config *ClearToolResultConfig) (adk.AgentMiddleware, error) { return adk.AgentMiddleware{ BeforeChatModel: newClearToolResult(ctx, config), }, nil } func newClearToolResult(ctx context.Context, config *ClearToolResultConfig) func(ctx context.Context, state *adk.ChatModelAgentState) error { if config == nil { config = &ClearToolResultConfig{} } // Set defaults toolResultTokenThreshold := config.ToolResultTokenThreshold if toolResultTokenThreshold == 0 { toolResultTokenThreshold = 20000 } keepRecentTokens := config.KeepRecentTokens if keepRecentTokens == 0 { keepRecentTokens = 40000 } placeholder := config.ClearToolResultPlaceholder if placeholder == "" { placeholder = "[Old tool result content cleared]" } // Set token estimator counter := config.TokenCounter if counter == nil { counter = defaultTokenCounter } return func(ctx context.Context, state *adk.ChatModelAgentState) error { return reduceByTokens(state, toolResultTokenThreshold, keepRecentTokens, placeholder, counter, config.ExcludeTools) } } // defaultTokenCounter estimates token count using character count / 4 // This is a simple heuristic that works reasonably well for most languages func defaultTokenCounter(msg *schema.Message) int { count := len(msg.Content) // Also count tool call arguments if present for _, tc := range msg.ToolCalls { count += len(tc.Function.Arguments) } // Estimate: roughly 4 characters per token return (count + 3) / 4 } // reduceByTokens reduces context based on tool result token threshold and recent message protection. // It clears old tool results when: // 1. The total tokens of all tool results exceed toolResultTokenThreshold // 2. 
Only tool results outside the keepRecentTokens range (from the end) are cleared func reduceByTokens(state *adk.ChatModelAgentState, toolResultTokenThreshold, keepRecentTokens int, placeholder string, counter func(*schema.Message) int, excludedTools []string) error { if len(state.Messages) == 0 { return nil } // Step 1: Calculate total tool result tokens totalToolResultTokens := 0 for _, msg := range state.Messages { if msg.Role == schema.Tool && msg.Content != placeholder { totalToolResultTokens += counter(msg) } } // If total tool result tokens are under the threshold, no reduction needed if totalToolResultTokens <= toolResultTokenThreshold { return nil } // Step 2: Calculate the index from which to protect recent messages // We need to find the starting index where cumulative tokens from the end <= keepRecentTokens recentStartIdx := len(state.Messages) cumulativeTokens := 0 for i := len(state.Messages) - 1; i >= 0; i-- { msgTokens := counter(state.Messages[i]) if cumulativeTokens+msgTokens > keepRecentTokens { // Adding this message would exceed the budget, so stop here recentStartIdx = i break } cumulativeTokens += msgTokens recentStartIdx = i } // Step 3: Clear tool results outside the protected range (before recentStartIdx) for i := 0; i < recentStartIdx; i++ { msg := state.Messages[i] if msg.Role == schema.Tool && msg.Content != placeholder && !excluded(msg.ToolName, excludedTools) { msg.Content = placeholder } } return nil } func excluded(name string, exclude []string) bool { for _, ex := range exclude { if name == ex { return true } } return false } ================================================ FILE: adk/middlewares/reduction/internal/clear_tool_result_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package internal import ( "context" "fmt" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/schema" ) func Test_reduceByTokens(t *testing.T) { type args struct { state *adk.ChatModelAgentState toolResultTokenThreshold int keepRecentTokens int placeholder string estimator func(*schema.Message) int } tests := []struct { name string args args wantErr assert.ErrorAssertionFunc validateState func(*testing.T, *adk.ChatModelAgentState) }{ { name: "no reduction when tool result tokens under threshold", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("hello"), schema.AssistantMessage("hi", nil), schema.ToolMessage("short tool result", "call-1", schema.WithToolName("tool1")), }, }, toolResultTokenThreshold: 100, keepRecentTokens: 500, placeholder: "[Old tool result content cleared]", estimator: defaultTokenCounter, }, wantErr: assert.NoError, validateState: func(t *testing.T, state *adk.ChatModelAgentState) { assert.Equal(t, "short tool result", state.Messages[2].Content) }, }, { name: "clear old tool results when total exceeds threshold", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("msg1"), schema.ToolMessage(strings.Repeat("a", 40), "call-1", schema.WithToolName("tool1")), // ~10 tokens (old) schema.UserMessage("msg2"), schema.ToolMessage(strings.Repeat("b", 40), "call-2", schema.WithToolName("tool2")), // ~10 tokens (old) schema.UserMessage("msg3"), schema.ToolMessage(strings.Repeat("c", 40), "call-3", 
schema.WithToolName("tool3")), // ~10 tokens (recent, protected) }, }, toolResultTokenThreshold: 20, keepRecentTokens: 10, placeholder: "[Old tool result content cleared]", estimator: defaultTokenCounter, }, wantErr: assert.NoError, validateState: func(t *testing.T, state *adk.ChatModelAgentState) { assert.Equal(t, "[Old tool result content cleared]", state.Messages[1].Content) assert.Equal(t, "[Old tool result content cleared]", state.Messages[3].Content) assert.Equal(t, strings.Repeat("c", 40), state.Messages[5].Content) }, }, { name: "protect recent messages even when tool results exceed threshold", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("old msg"), schema.ToolMessage(strings.Repeat("x", 100), "call-1", schema.WithToolName("tool1")), // ~25 tokens (old) schema.UserMessage("recent msg"), schema.ToolMessage(strings.Repeat("x", 100), "call-2", schema.WithToolName("tool2")), // ~25 tokens (recent, protected) }, }, toolResultTokenThreshold: 10, keepRecentTokens: 20, placeholder: "[Old tool result content cleared]", estimator: defaultTokenCounter, }, wantErr: assert.NoError, validateState: func(t *testing.T, state *adk.ChatModelAgentState) { // Total tool result tokens = 50, exceeds threshold of 10 // But last 200 tokens are protected (includes last 2 messages) // So only the first tool result should be cleared assert.Equal(t, "[Old tool result content cleared]", state.Messages[1].Content) assert.Equal(t, strings.Repeat("x", 100), state.Messages[3].Content) }, }, { name: "custom placeholder text", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("msg"), schema.ToolMessage(strings.Repeat("x", 100), "call-1", schema.WithToolName("tool1")), schema.UserMessage(strings.Repeat("x", 100)), }, }, toolResultTokenThreshold: 10, keepRecentTokens: 20, placeholder: "[历史工具结果已清除]", estimator: defaultTokenCounter, }, wantErr: assert.NoError, validateState: func(t *testing.T, state 
*adk.ChatModelAgentState) { assert.Equal(t, "[历史工具结果已清除]", state.Messages[1].Content) }, }, { name: "no tool messages", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("msg 1"), schema.AssistantMessage("response 1", nil), schema.UserMessage("msg 2"), schema.AssistantMessage("response 2", nil), }, }, toolResultTokenThreshold: 10, keepRecentTokens: 10, placeholder: "[Old tool result content cleared]", estimator: defaultTokenCounter, }, wantErr: assert.NoError, validateState: func(t *testing.T, state *adk.ChatModelAgentState) { // All messages should remain unchanged assert.Equal(t, "msg 1", state.Messages[0].Content) assert.Equal(t, "response 1", state.Messages[1].Content) assert.Equal(t, "msg 2", state.Messages[2].Content) assert.Equal(t, "response 2", state.Messages[3].Content) }, }, { name: "empty messages", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{}, }, toolResultTokenThreshold: 100, keepRecentTokens: 500, placeholder: "[Old tool result content cleared]", estimator: defaultTokenCounter, }, wantErr: assert.NoError, validateState: func(t *testing.T, state *adk.ChatModelAgentState) { assert.Empty(t, state.Messages) }, }, { name: "custom token estimator - word count", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("hello world"), schema.ToolMessage("this is a long tool result", "call-1", schema.WithToolName("tool1")), // 6 words (old) schema.UserMessage("another message"), schema.ToolMessage("recent tool result here", "call-2", schema.WithToolName("tool2")), // 4 words (recent) }, }, toolResultTokenThreshold: 9, // 10 words total threshold keepRecentTokens: 5, // 15 words protection budget placeholder: "[Old tool result content cleared]", estimator: func(msg *schema.Message) int { if msg.Content == "" { return 0 } words := 1 for _, ch := range msg.Content { if ch == ' ' { words++ } } return words }, }, wantErr: assert.NoError, validateState: func(t 
*testing.T, state *adk.ChatModelAgentState) { assert.Equal(t, "[Old tool result content cleared]", state.Messages[1].Content) assert.Equal(t, "recent tool result here", state.Messages[3].Content) }, }, { name: "already cleared results are not counted", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("msg1"), schema.ToolMessage("[Old tool result content cleared]", "call-1", schema.WithToolName("tool1")), // Already cleared schema.UserMessage("msg2"), schema.ToolMessage(strings.Repeat("a", 100), "call-2", schema.WithToolName("tool2")), // New long result }, }, toolResultTokenThreshold: 10, keepRecentTokens: 20, placeholder: "[Old tool result content cleared]", estimator: defaultTokenCounter, }, wantErr: assert.NoError, validateState: func(t *testing.T, state *adk.ChatModelAgentState) { // Only the new long result counts toward the threshold // Both should have placeholder assert.Equal(t, "[Old tool result content cleared]", state.Messages[1].Content) assert.Equal(t, strings.Repeat("a", 100), state.Messages[3].Content) }, }, { name: "all tool results within protected range", args: args{ state: &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("msg1"), schema.ToolMessage(strings.Repeat("a", 40), "call-1", schema.WithToolName("tool1")), // ~10 tokens schema.UserMessage("msg2"), schema.ToolMessage(strings.Repeat("b", 40), "call-2", schema.WithToolName("tool2")), // ~10 tokens }, }, toolResultTokenThreshold: 10, // Low threshold (will exceed) keepRecentTokens: 1000, // Very high protection (protects all) placeholder: "[Old tool result content cleared]", estimator: defaultTokenCounter, }, wantErr: assert.NoError, validateState: func(t *testing.T, state *adk.ChatModelAgentState) { // All messages are within protected range, nothing should be cleared assert.Equal(t, strings.Repeat("a", 40), state.Messages[1].Content) assert.Equal(t, strings.Repeat("b", 40), state.Messages[3].Content) }, }, } for _, tt := range 
tests { t.Run(tt.name, func(t *testing.T) { err := reduceByTokens(tt.args.state, tt.args.toolResultTokenThreshold, tt.args.keepRecentTokens, tt.args.placeholder, tt.args.estimator, []string{}) tt.wantErr(t, err, fmt.Sprintf("reduceByTokens(%v, %v, %v, %v)", tt.args.state, tt.args.toolResultTokenThreshold, tt.args.keepRecentTokens, tt.args.placeholder)) if tt.validateState != nil { tt.validateState(t, tt.args.state) } }) } } func Test_newClearToolResult(t *testing.T) { ctx := context.Background() t.Run("nil config uses defaults", func(t *testing.T) { fn := newClearToolResult(ctx, nil) assert.NotNil(t, fn) // Test that function works with nil config (uses defaults) state := &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("hello"), schema.ToolMessage("short result", "call-1", schema.WithToolName("tool1")), }, } err := fn(ctx, state) assert.NoError(t, err) // Default threshold is 20000, so short result should not be cleared assert.Equal(t, "short result", state.Messages[1].Content) }) t.Run("empty config uses defaults", func(t *testing.T) { fn := newClearToolResult(ctx, &ClearToolResultConfig{}) assert.NotNil(t, fn) state := &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("hello"), schema.ToolMessage("short result", "call-1", schema.WithToolName("tool1")), }, } err := fn(ctx, state) assert.NoError(t, err) assert.Equal(t, "short result", state.Messages[1].Content) }) } ================================================ FILE: adk/middlewares/reduction/internal/large_tool_result.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package internal import ( "bufio" "context" "errors" "fmt" "io" "strings" "unicode/utf8" "github.com/slongfield/pyfmt" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) const ( tooLargeToolMessage = `Tool result too large, the result of this tool call {tool_call_id} was saved in the filesystem at this path: {file_path} You can read the result from the filesystem by using the '{read_file_tool_name}' tool, but make sure to only read part of the result at a time. You can do this by specifying an offset and limit in the '{read_file_tool_name}' tool call. For example, to read the first 100 lines, you can use the '{read_file_tool_name}' tool with offset=0 and limit=100. 
Here are the first 10 lines of the result: {content_sample}` tooLargeToolMessageChinese = `工具结果过大,此工具调用 {tool_call_id} 的结果已保存到文件系统的以下路径:{file_path} 你可以使用 '{read_file_tool_name}' 工具从文件系统读取结果,但请确保每次只读取部分结果。 你可以通过在 '{read_file_tool_name}' 工具调用中指定 offset 和 limit 来实现。 例如,要读取前 100 行,你可以使用 '{read_file_tool_name}' 工具,设置 offset=0 和 limit=100。 以下是结果的前 10 行: {content_sample}` ) type toolResultOffloadingConfig struct { Backend Backend ReadFileToolName string TokenLimit int PathGenerator func(ctx context.Context, input *compose.ToolInput) (string, error) TokenCounter func(msg *schema.Message) int } func newToolResultOffloading(_ context.Context, config *toolResultOffloadingConfig) compose.ToolMiddleware { offloading := &toolResultOffloading{ backend: config.Backend, tokenLimit: config.TokenLimit, pathGenerator: config.PathGenerator, toolName: config.ReadFileToolName, counter: config.TokenCounter, } if offloading.tokenLimit == 0 { offloading.tokenLimit = 20000 } if offloading.pathGenerator == nil { offloading.pathGenerator = func(ctx context.Context, input *compose.ToolInput) (string, error) { return fmt.Sprintf("/large_tool_result/%s", input.CallID), nil } } if len(offloading.toolName) == 0 { offloading.toolName = "read_file" } if offloading.counter == nil { offloading.counter = defaultTokenCounter } return compose.ToolMiddleware{ Invokable: offloading.invoke, Streamable: offloading.stream, } } type toolResultOffloading struct { backend Backend tokenLimit int pathGenerator func(ctx context.Context, input *compose.ToolInput) (string, error) toolName string counter func(msg *schema.Message) int } func (t *toolResultOffloading) invoke(endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { output, err := endpoint(ctx, input) if err != nil { return nil, err } result, err := t.handleResult(ctx, output.Result, input) if err != nil { return nil, err } return 
&compose.ToolOutput{Result: result}, nil } } func (t *toolResultOffloading) stream(endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { output, err := endpoint(ctx, input) if err != nil { return nil, err } result, err := concatString(output.Result) if err != nil { return nil, err } result, err = t.handleResult(ctx, result, input) if err != nil { return nil, err } return &compose.StreamToolOutput{Result: schema.StreamReaderFromArray([]string{result})}, nil } } func (t *toolResultOffloading) handleResult(ctx context.Context, result string, input *compose.ToolInput) (string, error) { if t.counter(schema.ToolMessage(result, input.CallID, schema.WithToolName(input.Name))) > t.tokenLimit*4 { path, err := t.pathGenerator(ctx, input) if err != nil { return "", err } nResult := formatToolMessage(result) tpl := internal.SelectPrompt(internal.I18nPrompts{ English: tooLargeToolMessage, Chinese: tooLargeToolMessageChinese, }) nResult, err = pyfmt.Fmt(tpl, map[string]any{ "tool_call_id": input.CallID, "file_path": path, "content_sample": nResult, "read_file_tool_name": t.toolName, }) if err != nil { return "", err } err = t.backend.Write(ctx, &filesystem.WriteRequest{ FilePath: path, Content: result, }) if err != nil { return "", err } return nResult, nil } return result, nil } func concatString(sr *schema.StreamReader[string]) (string, error) { if sr == nil { return "", errors.New("stream is nil") } sb := strings.Builder{} for { str, err := sr.Recv() if errors.Is(err, io.EOF) { return sb.String(), nil } if err != nil { return "", err } sb.WriteString(str) } } func formatToolMessage(s string) string { reader := bufio.NewScanner(strings.NewReader(s)) var b strings.Builder lineNum := 1 for reader.Scan() { if lineNum > 10 { break } line := reader.Text() if utf8.RuneCountInString(line) > 1000 { runes := []rune(line) line = string(runes[:1000]) } 
b.WriteString(fmt.Sprintf("%d: %s\n", lineNum, line)) lineNum++ } return b.String() } ================================================ FILE: adk/middlewares/reduction/internal/large_tool_result_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package internal import ( "context" "errors" "fmt" "io" "strings" "testing" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) // mockBackend is a simple in-memory backend for testing type mockBackend struct { files map[string]string } func newMockBackend() *mockBackend { return &mockBackend{ files: make(map[string]string), } } func (m *mockBackend) Write(_ context.Context, wr *filesystem.WriteRequest) error { m.files[wr.FilePath] = wr.Content return nil } func TestToolResultOffloading_SmallResult(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 100, // Small limit for testing } middleware := newToolResultOffloading(ctx, config) // Create a mock endpoint that returns a small result smallResult := "This is a small result" mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: smallResult}, nil } // Wrap the endpoint with the middleware wrappedEndpoint := middleware.Invokable(mockEndpoint) // Execute input := 
&compose.ToolInput{ Name: "test_tool", CallID: "call_123", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Small result should pass through unchanged if output.Result != smallResult { t.Errorf("expected result %q, got %q", smallResult, output.Result) } // No file should be written if len(backend.files) != 0 { t.Errorf("expected no files to be written, got %d files", len(backend.files)) } } func TestToolResultOffloading_LargeResult(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, // Very small limit to trigger offloading } middleware := newToolResultOffloading(ctx, config) // Create a large result (more than 10 * 4 = 40 bytes) largeResult := strings.Repeat("This is a long line of text that will exceed the token limit.\n", 10) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: largeResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_456", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Result should be replaced with a message if !strings.Contains(output.Result, "Tool result too large") { t.Errorf("expected result to contain 'Tool result too large', got %q", output.Result) } if !strings.Contains(output.Result, "call_456") { t.Errorf("expected result to contain call ID 'call_456', got %q", output.Result) } if !strings.Contains(output.Result, "/large_tool_result/call_456") { t.Errorf("expected result to contain file path, got %q", output.Result) } // File should be written if len(backend.files) != 1 { t.Fatalf("expected 1 file to be written, got %d files", len(backend.files)) } savedContent, ok := backend.files["/large_tool_result/call_456"] if !ok { t.Fatalf("expected file at /large_tool_result/call_456, 
got files: %v", backend.files) } if savedContent != largeResult { t.Errorf("saved content doesn't match original result") } } func TestToolResultOffloading_CustomPathGenerator(t *testing.T) { ctx := context.Background() backend := newMockBackend() customPath := "/custom/path/result.txt" config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, PathGenerator: func(ctx context.Context, input *compose.ToolInput) (string, error) { return customPath, nil }, } middleware := newToolResultOffloading(ctx, config) largeResult := strings.Repeat("Large content ", 100) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: largeResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_789", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Check custom path is used if !strings.Contains(output.Result, customPath) { t.Errorf("expected result to contain custom path %q, got %q", customPath, output.Result) } // File should be written to custom path savedContent, ok := backend.files[customPath] if !ok { t.Fatalf("expected file at %q, got files: %v", customPath, backend.files) } if savedContent != largeResult { t.Errorf("saved content doesn't match original result") } } func TestToolResultOffloading_PathGeneratorError(t *testing.T) { ctx := context.Background() backend := newMockBackend() expectedErr := errors.New("path generation failed") config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, PathGenerator: func(ctx context.Context, input *compose.ToolInput) (string, error) { return "", expectedErr }, } middleware := newToolResultOffloading(ctx, config) largeResult := strings.Repeat("Large content ", 100) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: 
largeResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_error", } _, err := wrappedEndpoint(ctx, input) if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, expectedErr) { t.Errorf("expected error %v, got %v", expectedErr, err) } } func TestToolResultOffloading_EndpointError(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 100, } middleware := newToolResultOffloading(ctx, config) expectedErr := errors.New("endpoint execution failed") mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return nil, expectedErr } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_endpoint_error", } _, err := wrappedEndpoint(ctx, input) if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, expectedErr) { t.Errorf("expected error %v, got %v", expectedErr, err) } } func TestToolResultOffloading_DefaultTokenLimit(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 0, // Should default to 20000 } middleware := newToolResultOffloading(ctx, config) // Create a result smaller than 20000 * 4 = 80000 bytes smallResult := strings.Repeat("x", 1000) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: smallResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_default", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Should pass through unchanged if output.Result != smallResult { t.Errorf("expected result to pass through unchanged") } // No file should be written if len(backend.files) 
!= 0 { t.Errorf("expected no files to be written, got %d files", len(backend.files)) } } func TestToolResultOffloading_Stream(t *testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, } middleware := newToolResultOffloading(ctx, config) // Create a streaming endpoint that returns large content largeResult := strings.Repeat("Large streaming content ", 100) mockStreamEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { // Split the result into chunks chunks := []string{largeResult[:len(largeResult)/2], largeResult[len(largeResult)/2:]} return &compose.StreamToolOutput{ Result: schema.StreamReaderFromArray(chunks), }, nil } wrappedEndpoint := middleware.Streamable(mockStreamEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_stream", } output, err := wrappedEndpoint(ctx, input) if err != nil { t.Fatalf("unexpected error: %v", err) } // Read the stream var result strings.Builder for { chunk, err := output.Result.Recv() if errors.Is(err, io.EOF) { break } if err != nil { t.Fatalf("error reading stream: %v", err) } result.WriteString(chunk) } resultStr := result.String() // Result should be replaced with a message if !strings.Contains(resultStr, "Tool result too large") { t.Errorf("expected result to contain 'Tool result too large', got %q", resultStr) } if !strings.Contains(resultStr, "call_stream") { t.Errorf("expected result to contain call ID 'call_stream', got %q", resultStr) } // File should be written if len(backend.files) != 1 { t.Fatalf("expected 1 file to be written, got %d files", len(backend.files)) } savedContent, ok := backend.files["/large_tool_result/call_stream"] if !ok { t.Fatalf("expected file at /large_tool_result/call_stream, got files: %v", backend.files) } if savedContent != largeResult { t.Errorf("saved content doesn't match original result") } } func TestToolResultOffloading_StreamError(t 
*testing.T) { ctx := context.Background() backend := newMockBackend() config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, } middleware := newToolResultOffloading(ctx, config) expectedErr := errors.New("stream endpoint failed") mockStreamEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { return nil, expectedErr } wrappedEndpoint := middleware.Streamable(mockStreamEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_stream_error", } _, err := wrappedEndpoint(ctx, input) if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, expectedErr) { t.Errorf("expected error %v, got %v", expectedErr, err) } } func TestFormatToolMessage(t *testing.T) { tests := []struct { name string input string expected string }{ { name: "single line", input: "single line", expected: "1: single line\n", }, { name: "multiple lines", input: "line1\nline2\nline3", expected: "1: line1\n2: line2\n3: line3\n", }, { name: "more than 10 lines", input: "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12", expected: "1: 1\n2: 2\n3: 3\n4: 4\n5: 5\n6: 6\n7: 7\n8: 8\n9: 9\n10: 10\n", }, { name: "long line truncation", input: strings.Repeat("a", 1500), expected: fmt.Sprintf("1: %s\n", strings.Repeat("a", 1000)), }, { name: "unicode characters", input: "你好世界\n测试", expected: "1: 你好世界\n2: 测试\n", }, { name: "long unicode line", input: strings.Repeat("你", 1500), expected: fmt.Sprintf("1: %s\n", strings.Repeat("你", 1000)), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := formatToolMessage(tt.input) if result != tt.expected { t.Errorf("formatToolMessage() = %q, want %q", result, tt.expected) } }) } } func TestConcatString(t *testing.T) { tests := []struct { name string chunks []string expected string expectError bool }{ { name: "single chunk", chunks: []string{"hello"}, expected: "hello", }, { name: "multiple chunks", chunks: []string{"hello", " ", "world"}, expected: "hello world", }, { name: 
"empty chunks", chunks: []string{"", "", ""}, expected: "", }, { name: "mixed chunks", chunks: []string{"a", "", "b", "c"}, expected: "abc", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sr := schema.StreamReaderFromArray(tt.chunks) result, err := concatString(sr) if tt.expectError { if err == nil { t.Error("expected error, got nil") } return } if err != nil { t.Fatalf("unexpected error: %v", err) } if result != tt.expected { t.Errorf("concatString() = %q, want %q", result, tt.expected) } }) } // Test nil stream t.Run("nil stream", func(t *testing.T) { _, err := concatString(nil) if err == nil { t.Error("expected error for nil stream, got nil") } if !strings.Contains(err.Error(), "stream is nil") { t.Errorf("expected 'stream is nil' error, got %v", err) } }) } func TestToolResultOffloading_BackendWriteError(t *testing.T) { ctx := context.Background() // Create a backend that fails on write backend := &failingBackend{ writeErr: errors.New("write failed"), } config := &toolResultOffloadingConfig{ Backend: backend, TokenLimit: 10, } middleware := newToolResultOffloading(ctx, config) largeResult := strings.Repeat("Large content ", 100) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: largeResult}, nil } wrappedEndpoint := middleware.Invokable(mockEndpoint) input := &compose.ToolInput{ Name: "test_tool", CallID: "call_write_error", } _, err := wrappedEndpoint(ctx, input) if err == nil { t.Fatal("expected error, got nil") } if !strings.Contains(err.Error(), "write failed") { t.Errorf("expected 'write failed' error, got %v", err) } } // failingBackend is a mock backend that can be configured to fail type failingBackend struct { writeErr error } func (f *failingBackend) Write(context.Context, *filesystem.WriteRequest) error { if f.writeErr != nil { return f.writeErr } return nil } ================================================ FILE: 
adk/middlewares/reduction/internal/tool_result.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package internal import ( "context" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) // Backend defines the interface provided by the user to implement file storage. // It is used to save the content of large tool results to a persistent storage. type Backend interface { Write(context.Context, *filesystem.WriteRequest) error } // ToolResultConfig configures the tool result reduction middleware. type ToolResultConfig struct { // ClearingTokenThreshold is the threshold for the total token count of all tool results. // When the sum of all tool result tokens exceeds this threshold, old tool results // (outside the KeepRecentTokens range) will be replaced with a placeholder. // Token estimation uses a simple heuristic: character count / 4. // optional, 20000 by default ClearingTokenThreshold int // KeepRecentTokens is the token budget for recent messages to keep intact. // Messages within this token budget from the end will not have their tool results cleared, // even if the total tool result tokens exceed the threshold. // optional, 40000 by default KeepRecentTokens int // ClearToolResultPlaceholder is the text to replace old tool results with. 
// optional, "[Old tool result content cleared]" by default ClearToolResultPlaceholder string // TokenCounter is a custom function to estimate token count for a message. // optional, uses the default counter (character count / 4) if nil TokenCounter func(msg *schema.Message) int // ExcludeTools is a list of tool names whose results should never be cleared. // optional ExcludeTools []string // Backend is the storage backend for offloaded tool results. // required Backend Backend // OffloadingTokenLimit is the token threshold for a single tool result to trigger offloading. // When a single tool result exceeds OffloadingTokenLimit * 4 characters, it will be // offloaded to the filesystem. // optional, 20000 by default OffloadingTokenLimit int // ReadFileToolName is the name of the tool that LLM should use to read offloaded content. // This name will be included in the summary message sent to the LLM. // optional, "read_file" by default // // NOTE: If you are using the filesystem middleware, the read_file tool name // is exactly "read_file", which matches the default value. ReadFileToolName string // PathGenerator generates the write path for offloaded results. // optional, "/large_tool_result/{ToolCallID}" by default PathGenerator func(ctx context.Context, input *compose.ToolInput) (string, error) } // NewToolResultMiddleware creates a tool result reduction middleware. // This middleware combines two strategies to manage tool result tokens: // // 1. Clearing: Replaces old tool results with a placeholder when the total // tool result tokens exceed the threshold, while protecting recent messages. // // 2. Offloading: Writes large individual tool results to the filesystem and // returns a summary message guiding the LLM to read the full content. // // NOTE: If you are using the filesystem middleware (github.com/cloudwego/eino/adk/middlewares/filesystem), // this functionality is already included by default. 
Set Config.WithoutLargeToolResultOffloading = true // in the filesystem middleware if you want to use this middleware separately instead. // // NOTE: This middleware only handles offloading results to the filesystem. // You MUST also provide a read_file tool to your agent, otherwise the agent // will not be able to read the offloaded content. You can either: // - Use the filesystem middleware (github.com/cloudwego/eino/adk/middlewares/filesystem) // which provides the read_file tool automatically, OR // - Implement your own read_file tool that reads from the same Backend func NewToolResultMiddleware(ctx context.Context, cfg *ToolResultConfig) (adk.AgentMiddleware, error) { bc := newClearToolResult(ctx, &ClearToolResultConfig{ ToolResultTokenThreshold: cfg.ClearingTokenThreshold, KeepRecentTokens: cfg.KeepRecentTokens, ClearToolResultPlaceholder: cfg.ClearToolResultPlaceholder, TokenCounter: cfg.TokenCounter, ExcludeTools: cfg.ExcludeTools, }) tm := newToolResultOffloading(ctx, &toolResultOffloadingConfig{ Backend: cfg.Backend, ReadFileToolName: cfg.ReadFileToolName, TokenLimit: cfg.OffloadingTokenLimit, PathGenerator: cfg.PathGenerator, }) return adk.AgentMiddleware{ BeforeChatModel: bc, WrapToolCall: tm, }, nil } ================================================ FILE: adk/middlewares/reduction/legacy.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */

package reduction

import "github.com/cloudwego/eino/adk/middlewares/reduction/internal"

// Historical compatibility exports for the legacy reduction middleware APIs,
// which are implemented in the internal package.
//
// The legacy constructors (NewClearToolResult, NewToolResultMiddleware) and their
// configuration types are deprecated: new code should use New, which is developed
// and maintained in this package and returns an adk.ChatModelAgentMiddleware.
//
// Existing code relying on these exports will continue to work indefinitely,
// but no new features or bug fixes will be backported to this compatibility layer.
type (
	// ClearToolResultConfig is the configuration of the legacy clearing middleware
	// (see NewClearToolResult).
	//
	// Deprecated: Use Config with New instead.
	ClearToolResultConfig = internal.ClearToolResultConfig

	// ToolResultConfig is the configuration accepted by NewToolResultMiddleware.
	//
	// Deprecated: Use Config with New instead.
	ToolResultConfig = internal.ToolResultConfig

	// Backend is the storage backend that offloaded tool content is written to.
	// It is NOT deprecated: Config.Backend and ToolReductionConfig.Backend in this
	// package use this type as well.
	Backend = internal.Backend
)

var (
	// NewClearToolResult creates a new middleware that clears old tool results
	// based on token thresholds while protecting recent messages.
	//
	// Deprecated: Use New instead, which provides a more comprehensive tool result reduction
	// middleware with both truncation and clearing strategies. New returns a ChatModelAgentMiddleware
	// for better context propagation through wrapper methods.
	NewClearToolResult = internal.NewClearToolResult

	// NewToolResultMiddleware creates a tool result reduction middleware.
	// This middleware combines two strategies to manage tool result tokens:
	//
	// 1. Clearing: Replaces old tool results with a placeholder when the total
	//    tool result tokens exceed the threshold, while protecting recent messages.
	//
	// 2. Offloading: Writes large individual tool results to the filesystem and
	//    returns a summary message guiding the LLM to read the full content.
	//
	// NOTE: If you are using the filesystem middleware (github.com/cloudwego/eino/adk/middlewares/filesystem),
	// this functionality is already included by default. Set Config.WithoutLargeToolResultOffloading = true
	// in the filesystem middleware if you want to use this middleware separately instead.
	//
	// NOTE: This middleware only handles offloading results to the filesystem.
	// You MUST also provide a read_file tool to your agent, otherwise the agent
	// will not be able to read the offloaded content. You can either:
	//   - Use the filesystem middleware (github.com/cloudwego/eino/adk/middlewares/filesystem)
	//     which provides the read_file tool automatically, OR
	//   - Implement your own read_file tool that reads from the same Backend
	//
	// Deprecated: Use New instead, which provides a more comprehensive tool result reduction
	// middleware with both truncation and clearing strategies, per-tool configuration support,
	// and returns a ChatModelAgentMiddleware for better context propagation through wrapper methods.
	NewToolResultMiddleware = internal.NewToolResultMiddleware
)

================================================
FILE: adk/middlewares/reduction/reduction.go
================================================
/*
 * Copyright 2026 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package reduction

import (
	"context"
	"fmt"
	"io"
	"path/filepath"
	"strings"

	"github.com/bytedance/sonic"
	"github.com/google/uuid"
	"github.com/slongfield/pyfmt"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/adk/filesystem"
	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/schema"
)

// Config is the configuration for tool reduction middleware.
// This middleware manages tool outputs in two phases to optimize context usage:
//
// 1.
Truncation Phase: // Triggered immediately after a tool execution completes. // If the tool output length exceeds MaxLengthForTrunc, the full content is saved // to the configured Backend, and the tool output is replaced with a truncated notice. // This prevents immediate context overflow from a single large tool output. // // 2. Clear Phase: // Triggered before sending messages to the model (in BeforeModelRewriteState). // If the total token count exceeds MaxTokensForClear, the middleware iterates through // historical messages. Based on RootDir and ClearRetentionSuffixLimit, it offloads tool call arguments and results // to the Backend to reduce token usage, keeping the conversation within limits while retaining access to the // important information. After all, ClearPostProcess will be called, which you could save or notify current state. type Config struct { // Backend is the storage backend where truncated content will be saved. // Required. Backend Backend // SkipTruncation skip truncating. SkipTruncation bool // SkipClear skip clearing. SkipClear bool // ReadFileToolName is tool name used to retrieve from file. // After offloading content to file, you should give agent the same tool to retrieve content. // Required. Default is "read_file". ReadFileToolName string // RootDir root dir to save truncated/cleared content. // Required. // Default is /tmp, truncated content saves to /tmp/trunc/{tool_call_id}, cleared content saves to /tmp/clear/{tool_call_id} RootDir string // MaxLengthForTrunc is the maximum allowed length of the tool output. // If the output exceeds this length, it will be truncated. // Required. Default is 50000. MaxLengthForTrunc int // TokenCounter is used to count the number of tokens in the conversation messages. // It is used to determine when to trigger clearing based on token usage, and token usage after clearing. // Required. 
TokenCounter func(ctx context.Context, msg []adk.Message, tools []*schema.ToolInfo) (int64, error) // MaxTokensForClear is the maximum number of tokens allowed in the conversation before clearing is attempted. // Required. Default is 30000. MaxTokensForClear int64 // ClearRetentionSuffixLimit is the number of most recent messages to retain without clearing. // This ensures the model has some immediate context. // Optional. Default is 1. ClearRetentionSuffixLimit int // ClearPostProcess is clear post process handler. // Optional. ClearPostProcess func(ctx context.Context, state *adk.ChatModelAgentState) context.Context // ToolConfig is the specific configuration that applies to tools by name. // This configuration takes precedence over GeneralConfig for the specified tools. // Optional. ToolConfig map[string]*ToolReductionConfig } type ToolReductionConfig struct { // Backend is the storage backend where truncated content will be saved. // Required. Backend Backend // SkipTruncation skip truncation for this tool. SkipTruncation bool // TruncHandler is used to process tool call results during truncation. // Optional. Default using defaultTruncHandler when SkipTruncation is false but TruncHandler is nil. TruncHandler func(ctx context.Context, detail *ToolDetail) (*TruncResult, error) // SkipClear skip clear for this tool. SkipClear bool // ClearHandler is used to process tool call arguments and results during clearing. // Optional. Default using defaultClearHandler when SkipClear is false but ClearHandler is nil. ClearHandler func(ctx context.Context, detail *ToolDetail) (*ClearResult, error) } type ToolDetail struct { // ToolContext provides metadata about the tool call (e.g., tool name, call ID). ToolContext *adk.ToolContext // ToolArgument contains the arguments passed to the tool. ToolArgument *schema.ToolArgument // ToolResult contains the output returned by the tool. 
ToolResult *schema.ToolResult } type TruncResult struct { // NeedTrunc indicates whether the tool result should be truncated. NeedTrunc bool // ToolResult contains the result returned by the tool after trunc // Required when NeedTrunc is true. ToolResult *schema.ToolResult // NeedOffload indicates whether the tool result should be offloaded. NeedOffload bool // OffloadFilePath is the path where the offloaded content should be stored. // This path is typically relative to the backend's root. // Required when NeedOffload is true. OffloadFilePath string // OffloadContent is the actual content to be written to the storage backend. // Required when NeedOffload is true. OffloadContent string } // ClearResult contains the result of the Handler's decision. type ClearResult struct { // NeedClear indicates whether the tool argument and result should be cleared. NeedClear bool // ToolArgument contains the arguments passed to the tool after clear. // Required when NeedClear is true. ToolArgument *schema.ToolArgument // ToolResult contains the output returned by the tool after clear. // Required when NeedClear is true ToolResult *schema.ToolResult // NeedOffload indicates whether the tool argument and result should be offloaded. NeedOffload bool // OffloadFilePath is the path where the offloaded content should be stored. // This path is typically relative to the backend's root. // Required when NeedOffload is true. OffloadFilePath string // OffloadContent is the actual content to be written to the storage backend. // Required when NeedOffload is true. 
OffloadContent string } func (t *Config) copyAndFillDefaults() (*Config, error) { cfg := &Config{ Backend: t.Backend, SkipTruncation: t.SkipTruncation, SkipClear: t.SkipClear, ReadFileToolName: t.ReadFileToolName, RootDir: t.RootDir, MaxLengthForTrunc: t.MaxLengthForTrunc, TokenCounter: t.TokenCounter, MaxTokensForClear: t.MaxTokensForClear, ClearRetentionSuffixLimit: t.ClearRetentionSuffixLimit, ClearPostProcess: t.ClearPostProcess, } if cfg.TokenCounter == nil { cfg.TokenCounter = defaultTokenCounter } if cfg.ClearRetentionSuffixLimit == 0 { cfg.ClearRetentionSuffixLimit = 1 } if cfg.ReadFileToolName == "" { cfg.ReadFileToolName = "read_file" } if cfg.RootDir == "" { cfg.RootDir = "/tmp" } if cfg.MaxLengthForTrunc == 0 { cfg.MaxLengthForTrunc = 50000 } if t.ToolConfig != nil { cfg.ToolConfig = make(map[string]*ToolReductionConfig, len(t.ToolConfig)) for toolName, trc := range t.ToolConfig { cpConfig := &ToolReductionConfig{ Backend: trc.Backend, SkipTruncation: trc.SkipTruncation, SkipClear: trc.SkipClear, TruncHandler: trc.TruncHandler, ClearHandler: trc.ClearHandler, } cfg.ToolConfig[toolName] = cpConfig } } return cfg, nil } // New creates tool reduction middleware from config func New(_ context.Context, config *Config) (adk.ChatModelAgentMiddleware, error) { var err error if config == nil { return nil, fmt.Errorf("config must not be nil") } if config.Backend == nil && !config.SkipTruncation { return nil, fmt.Errorf("backend must be set when not skipping truncation") } config, err = config.copyAndFillDefaults() if err != nil { return nil, err } defaultReductionConfig := &ToolReductionConfig{ Backend: config.Backend, SkipTruncation: config.SkipTruncation, SkipClear: config.SkipClear, } if !defaultReductionConfig.SkipTruncation { defaultReductionConfig.TruncHandler = defaultTruncHandler(config.RootDir, config.MaxLengthForTrunc) } if !defaultReductionConfig.SkipClear { defaultReductionConfig.ClearHandler = defaultClearHandler(config.RootDir, config.Backend != 
nil, config.ReadFileToolName) } return &toolReductionMiddleware{ config: config, defaultConfig: defaultReductionConfig, }, nil } type toolReductionMiddleware struct { adk.BaseChatModelAgentMiddleware config *Config defaultConfig *ToolReductionConfig } func (t *toolReductionMiddleware) getToolConfig(toolName string, sc scene) *ToolReductionConfig { if t.config.ToolConfig != nil { if cfg, ok := t.config.ToolConfig[toolName]; ok { if (sc == sceneTruncation && !cfg.SkipTruncation && cfg.TruncHandler == nil) || (sc == sceneClear && !cfg.SkipClear && cfg.ClearHandler == nil) { return t.defaultConfig } return cfg } } return t.defaultConfig } func (t *toolReductionMiddleware) WrapInvokableToolCall(_ context.Context, endpoint adk.InvokableToolCallEndpoint, tCtx *adk.ToolContext) (adk.InvokableToolCallEndpoint, error) { cfg := t.getToolConfig(tCtx.Name, sceneTruncation) if cfg == nil || cfg.TruncHandler == nil { return endpoint, nil } return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { output, err := endpoint(ctx, argumentsInJSON, opts...) 
if err != nil { return "", err } detail := &ToolDetail{ ToolContext: tCtx, ToolArgument: &schema.ToolArgument{ Text: argumentsInJSON, }, ToolResult: &schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: output}, }, }, } truncResult, err := cfg.TruncHandler(ctx, detail) if err != nil { return "", err } if !truncResult.NeedTrunc { return output, nil } if truncResult.NeedOffload { if cfg.Backend == nil { return "", fmt.Errorf("truncation: no backend for offload") } if err = cfg.Backend.Write(ctx, &filesystem.WriteRequest{ FilePath: truncResult.OffloadFilePath, Content: truncResult.OffloadContent, }); err != nil { return "", err } } return truncResult.ToolResult.Parts[0].Text, nil }, nil } func (t *toolReductionMiddleware) WrapStreamableToolCall(_ context.Context, endpoint adk.StreamableToolCallEndpoint, tCtx *adk.ToolContext) (adk.StreamableToolCallEndpoint, error) { cfg := t.getToolConfig(tCtx.Name, sceneTruncation) if cfg == nil || cfg.TruncHandler == nil { return endpoint, nil } return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { output, err := endpoint(ctx, argumentsInJSON, opts...) 
if err != nil { return nil, err } var chunks []string readers := output.Copy(2) output = readers[0] origResp := readers[1] defer output.Close() for { var recvErr error chunk, recvErr := output.Recv() if recvErr != nil { if recvErr != io.EOF { return origResp, nil } break } chunks = append(chunks, chunk) } result := strings.Join(chunks, "") detail := &ToolDetail{ ToolContext: tCtx, ToolArgument: &schema.ToolArgument{ Text: argumentsInJSON, }, ToolResult: &schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: result}, }, }, } truncResult, err := cfg.TruncHandler(ctx, detail) if err != nil { return nil, err } if !truncResult.NeedTrunc { return origResp, nil } origResp.Close() // close err resp when not using it if truncResult.NeedOffload { if cfg.Backend == nil { return nil, fmt.Errorf("truncation: no backend for offload") } if err = cfg.Backend.Write(ctx, &filesystem.WriteRequest{ FilePath: truncResult.OffloadFilePath, Content: truncResult.OffloadContent, }); err != nil { return nil, err } } return schema.StreamReaderFromArray([]string{truncResult.ToolResult.Parts[0].Text}), nil }, nil } func (t *toolReductionMiddleware) BeforeModelRewriteState(ctx context.Context, state *adk.ChatModelAgentState, mc *adk.ModelContext) ( context.Context, *adk.ChatModelAgentState, error) { var ( err error estimatedTokens int64 ) // init msg tokens estimatedTokens, err = t.config.TokenCounter(ctx, state.Messages, mc.Tools) if err != nil { return ctx, state, err } if estimatedTokens < t.config.MaxTokensForClear { return ctx, state, nil } // calc range var ( start = 0 end = len(state.Messages) ) for ; start < len(state.Messages); start++ { msg := state.Messages[start] if msg.Role == schema.Assistant && !getMsgOffloadedFlag(msg) { break } } retention := t.config.ClearRetentionSuffixLimit for ; retention > 0 && end > 0; end-- { msg := state.Messages[end-1] if msg.Role == schema.Assistant && len(msg.ToolCalls) > 0 { retention-- if retention == 0 { end-- 
break } } } if start >= end { return ctx, state, nil } // recursively handle tcMsgIndex := start for tcMsgIndex < end { tcMsg := state.Messages[tcMsgIndex] if tcMsg.Role == schema.Assistant && len(tcMsg.ToolCalls) > 0 { trMsgEnd := tcMsgIndex + 1 + len(tcMsg.ToolCalls) if trMsgEnd > len(state.Messages) { trMsgEnd = len(state.Messages) } j := tcMsgIndex for tcIndex, toolCall := range tcMsg.ToolCalls { j++ if j >= end { break } resultMsg := state.Messages[j] if resultMsg.Role != schema.Tool { // unexpected break } cfg := t.getToolConfig(toolCall.Function.Name, sceneClear) if cfg == nil || cfg.ClearHandler == nil { continue } toolResult, fromContent, toolResultErr := toolResultFromMessage(resultMsg) if toolResultErr != nil { return ctx, state, toolResultErr } td := &ToolDetail{ ToolContext: &adk.ToolContext{ Name: toolCall.Function.Name, CallID: toolCall.ID, }, ToolArgument: &schema.ToolArgument{ Text: toolCall.Function.Arguments, }, ToolResult: toolResult, } offloadInfo, offloadErr := cfg.ClearHandler(ctx, td) if offloadErr != nil { return ctx, state, offloadErr } if !offloadInfo.NeedClear { continue } if offloadInfo.NeedOffload { if cfg.Backend == nil { return ctx, state, fmt.Errorf("clear: no backend for offload") } writeErr := cfg.Backend.Write(ctx, &filesystem.WriteRequest{ FilePath: offloadInfo.OffloadFilePath, Content: offloadInfo.OffloadContent, }) if writeErr != nil { return ctx, state, writeErr } } tcMsg.ToolCalls[tcIndex].Function.Arguments = offloadInfo.ToolArgument.Text if fromContent { if len(offloadInfo.ToolResult.Parts) > 0 { resultMsg.Content = offloadInfo.ToolResult.Parts[0].Text } } else { var convErr error resultMsg.UserInputMultiContent, convErr = offloadInfo.ToolResult.ToMessageInputParts() if convErr != nil { return ctx, state, convErr } } } // set dedup flag setMsgOffloadedFlag(tcMsg) } tcMsgIndex++ } if t.config.ClearPostProcess != nil { ctx = t.config.ClearPostProcess(ctx, state) } return ctx, state, nil } // defaultTokenCounter estimates 
tokens, which treats one token as ~4 characters of text for common English text. // github.com/tiktoken-go/tokenizer is highly recommended to replace it. func defaultTokenCounter(_ context.Context, msgs []*schema.Message, tools []*schema.ToolInfo) (int64, error) { var tokens int64 for _, msg := range msgs { if msg == nil { continue } if cached, ok := getMsgCachedToken(msg); ok { tokens += cached continue } var sb strings.Builder sb.WriteString(string(msg.Role)) sb.WriteString("\n") sb.WriteString(msg.ReasoningContent) sb.WriteString("\n") sb.WriteString(msg.Content) sb.WriteString("\n") if msg.Role == schema.Assistant && len(msg.ToolCalls) > 0 { for _, tc := range msg.ToolCalls { sb.WriteString(tc.Function.Name) sb.WriteString("\n") sb.WriteString(tc.Function.Arguments) } } n := int64(len(sb.String()) / 4) setMsgCachedToken(msg, n) tokens += n } for _, tl := range tools { tl_ := *tl tl_.Extra = nil text, err := sonic.MarshalString(tl_) if err != nil { return 0, fmt.Errorf("failed to marshal tool info: %w", err) } tokens += int64(len(text) / 4) } return tokens, nil } func defaultTruncHandler(rootDir string, truncMaxLength int) func(ctx context.Context, detail *ToolDetail) (truncResult *TruncResult, err error) { return func(ctx context.Context, detail *ToolDetail) (offloadInfo *TruncResult, err error) { resultText := detail.ToolResult.Parts[0].Text if len(resultText) <= truncMaxLength { return &TruncResult{NeedTrunc: false}, nil } filePath := filepath.Join(rootDir, "trunc", detail.ToolContext.CallID) previewSize := truncMaxLength / 2 truncNotify, err := pyfmt.Fmt(getTruncFmt(), map[string]any{ "original_size": len(resultText), "file_path": filePath, "preview_size": previewSize, "preview_first": resultText[:previewSize], "preview_last": resultText[len(resultText)-previewSize:], }) if err != nil { return nil, err } return &TruncResult{ ToolResult: &schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: resultText[:truncMaxLength] + 
truncNotify}, }, }, NeedTrunc: true, NeedOffload: true, OffloadFilePath: filePath, OffloadContent: resultText, }, nil } } func defaultClearHandler(rootDir string, needOffload bool, readFileToolName string) func(ctx context.Context, detail *ToolDetail) (*ClearResult, error) { return func(ctx context.Context, detail *ToolDetail) (clearResult *ClearResult, err error) { if len(detail.ToolResult.Parts) == 0 { return &ClearResult{NeedClear: false}, nil } for _, part := range detail.ToolResult.Parts { if part.Type != schema.ToolPartTypeText { // brutal judge return nil, fmt.Errorf("default offload currently not support multimodal content type=%v", part.Type) } } fileName := detail.ToolContext.CallID if fileName == "" { fileName = uuid.NewString() } var nResult string if needOffload { filePath := filepath.Join(rootDir, "clear", fileName) nResult, err = pyfmt.Fmt(getClearWithOffloadingFmt(), map[string]any{ "file_path": filePath, "read_tool_name": readFileToolName, }) if err != nil { return nil, err } clearResult = &ClearResult{ ToolArgument: detail.ToolArgument, NeedClear: true, NeedOffload: true, OffloadFilePath: filePath, OffloadContent: detail.ToolResult.Parts[0].Text, } } else { nResult = getClearWithoutOffloadingFmt() clearResult = &ClearResult{ ToolArgument: detail.ToolArgument, NeedClear: true, NeedOffload: false, } } clearResult.ToolResult = &schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: nResult}, }, } return clearResult, nil } } func getMsgOffloadedFlag(msg *schema.Message) (offloaded bool) { if msg.Extra == nil { return false } v, ok := msg.Extra[msgReducedFlag].(bool) if !ok { return false } return v } func setMsgOffloadedFlag(msg *schema.Message) { if msg.Extra == nil { msg.Extra = make(map[string]any) } msg.Extra[msgReducedFlag] = true } func getMsgCachedToken(msg *schema.Message) (int64, bool) { if msg.Extra == nil { return 0, false } tokens, ok := msg.Extra[msgReducedTokens].(int64) return tokens, ok } func 
setMsgCachedToken(msg *schema.Message, tokens int64) { if msg.Extra == nil { msg.Extra = make(map[string]any) } msg.Extra[msgReducedTokens] = tokens } func toolResultFromMessage(msg *schema.Message) (result *schema.ToolResult, fromContent bool, err error) { if msg.Role != schema.Tool { return nil, false, fmt.Errorf("message role %s is not a tool", msg.Role) } if msg.Content != "" { return &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: msg.Content}}}, true, nil } result = &schema.ToolResult{Parts: make([]schema.ToolOutputPart, 0, len(msg.UserInputMultiContent))} for _, part := range msg.UserInputMultiContent { top, convErr := convMessageInputPartToToolOutputPart(part) if convErr != nil { return nil, false, convErr } result.Parts = append(result.Parts, top) } return result, false, nil } func convMessageInputPartToToolOutputPart(msgPart schema.MessageInputPart) (schema.ToolOutputPart, error) { switch msgPart.Type { case schema.ChatMessagePartTypeText: return schema.ToolOutputPart{ Type: schema.ToolPartTypeText, Text: msgPart.Text, }, nil case schema.ChatMessagePartTypeImageURL: return schema.ToolOutputPart{ Type: schema.ToolPartTypeImage, Image: &schema.ToolOutputImage{ MessagePartCommon: msgPart.Image.MessagePartCommon, }, }, nil case schema.ChatMessagePartTypeAudioURL: return schema.ToolOutputPart{ Type: schema.ToolPartTypeAudio, Audio: &schema.ToolOutputAudio{ MessagePartCommon: msgPart.Audio.MessagePartCommon, }, }, nil case schema.ChatMessagePartTypeVideoURL: return schema.ToolOutputPart{ Type: schema.ToolPartTypeVideo, Video: &schema.ToolOutputVideo{ MessagePartCommon: msgPart.Video.MessagePartCommon, }, }, nil case schema.ChatMessagePartTypeFileURL: return schema.ToolOutputPart{ Type: schema.ToolPartTypeFile, File: &schema.ToolOutputFile{ MessagePartCommon: msgPart.File.MessagePartCommon, }, }, nil default: return schema.ToolOutputPart{}, fmt.Errorf("unknown msg part type: %v", msgPart.Type) } } 
================================================ FILE: adk/middlewares/reduction/reduction_test.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package reduction import ( "context" "encoding/json" "fmt" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/components/tool/utils" "github.com/cloudwego/eino/schema" ) func TestReductionMiddlewareTrunc(t *testing.T) { ctx := context.Background() it := mockInvokableTool() st := mockStreamableTool() t.Run("test invokable max length trunc", func(t *testing.T) { tCtx := &adk.ToolContext{ Name: "mock_invokable_tool", CallID: "12345", } backend := filesystem.NewInMemoryBackend() config := &Config{ Backend: backend, ToolConfig: map[string]*ToolReductionConfig{ "mock_invokable_tool": { Backend: backend, SkipTruncation: false, TruncHandler: defaultTruncHandler("/tmp", 70), }, }, } mw, err := New(ctx, config) assert.NoError(t, err) exp := "hello worldhello worldhello worldhello worldhello worldhello worldhell\nOutput too large (199). 
Full output saved to: /tmp/trunc/12345\nPreview (first 35):\nhello worldhello worldhello worldhe\n\nPreview (last 35):\nldhello worldhello worldhello world\n\n" edp, err := mw.WrapInvokableToolCall(ctx, it.InvokableRun, tCtx) assert.NoError(t, err) resp, err := edp(ctx, `{"value":"asd"}`) assert.NoError(t, err) assert.Equal(t, exp, resp) content, err := backend.Read(ctx, &filesystem.ReadRequest{FilePath: "/tmp/trunc/12345"}) assert.NoError(t, err) expOrigContent := `hello worldhello worldhello worldhello worldhello worldhello worldhello worldhello worldhello worldhello world hello worldhello worldhello worldhello worldhello worldhello worldhello worldhello world` assert.Equal(t, expOrigContent, content.Content) }) t.Run("test streamable line and max length trunc", func(t *testing.T) { tCtx := &adk.ToolContext{ Name: "mock_streamable_tool", CallID: "54321", } backend := filesystem.NewInMemoryBackend() config := &Config{ SkipTruncation: true, ToolConfig: map[string]*ToolReductionConfig{ "mock_streamable_tool": { Backend: backend, SkipTruncation: false, TruncHandler: defaultTruncHandler("/tmp", 70), }, }, } mw, err := New(ctx, config) assert.NoError(t, err) exp := "hello worldhello worldhello worldhello worldhello worldhello worldhell\nOutput too large (199). 
Full output saved to: /tmp/trunc/54321\nPreview (first 35):\nhello worldhello worldhello worldhe\n\nPreview (last 35):\nldhello worldhello worldhello world\n\n" edp, err := mw.WrapStreamableToolCall(ctx, st.StreamableRun, tCtx) assert.NoError(t, err) resp, err := edp(ctx, `{"value":"asd"}`) assert.NoError(t, err) s, err := resp.Recv() assert.NoError(t, err) resp.Close() assert.Equal(t, exp, s) content, err := backend.Read(ctx, &filesystem.ReadRequest{FilePath: "/tmp/trunc/54321"}) assert.NoError(t, err) expOrigContent := `hello worldhello worldhello worldhello worldhello worldhello worldhello worldhello worldhello worldhello world hello worldhello worldhello worldhello worldhello worldhello worldhello worldhello world` assert.Equal(t, expOrigContent, content.Content) }) } func TestReductionMiddlewareClear(t *testing.T) { ctx := context.Background() it := mockInvokableTool() st := mockStreamableTool() tools := []tool.BaseTool{it, st} var toolsInfo []*schema.ToolInfo for _, bt := range tools { ti, _ := bt.Info(ctx) toolsInfo = append(toolsInfo, ti) } type OffloadContent struct { Arguments map[string]string `json:"arguments"` Result string `json:"result"` } t.Run("test default clear", func(t *testing.T) { backend := filesystem.NewInMemoryBackend() config := &Config{ SkipTruncation: true, TokenCounter: defaultTokenCounter, MaxTokensForClear: 20, ClearRetentionSuffixLimit: 0, ToolConfig: map[string]*ToolReductionConfig{ "get_weather": { Backend: backend, SkipClear: false, ClearHandler: defaultClearHandler("/tmp", true, "read_file"), }, }, } mw, err := New(ctx, config) assert.NoError(t, err) _, s, err := mw.BeforeModelRewriteState(ctx, &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.SystemMessage("you are a helpful assistant"), schema.UserMessage("If it's warmer than 20°C in London, set the thermostat to 20°C, otherwise set it to 18°C."), schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: schema.FunctionCall{Name: 
"get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_123456789"), schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_123456789"), }, }, &adk.ModelContext{ Tools: toolsInfo, }) assert.NoError(t, err) assert.Equal(t, []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[2].ToolCalls) assert.Equal(t, []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[4].ToolCalls) assert.Equal(t, "Tool result saved to: /tmp/clear/call_987654321\nUse read_file to view", s.Messages[3].Content) fileContent, err := backend.Read(ctx, &filesystem.ReadRequest{ FilePath: "/tmp/clear/call_987654321", }) assert.NoError(t, err) fileContentStr := strings.TrimPrefix(strings.TrimSpace(fileContent.Content), "1\t") assert.Equal(t, "Sunny", fileContentStr) }) t.Run("test default clear without offloading", func(t *testing.T) { config := &Config{ SkipTruncation: true, TokenCounter: defaultTokenCounter, MaxTokensForClear: 20, ClearRetentionSuffixLimit: 0, ToolConfig: map[string]*ToolReductionConfig{ "get_weather": { SkipClear: false, ClearHandler: defaultClearHandler("", false, ""), }, }, } mw, err := New(ctx, config) assert.NoError(t, err) _, s, err := mw.BeforeModelRewriteState(ctx, &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.SystemMessage("you are a helpful assistant"), schema.UserMessage("If it's warmer than 20°C in London, set the thermostat to 20°C, otherwise set it to 18°C."), schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: 
schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_123456789"), schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_123456789"), }, }, &adk.ModelContext{ Tools: toolsInfo, }) assert.NoError(t, err) assert.Equal(t, []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[2].ToolCalls) assert.Equal(t, []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[4].ToolCalls) assert.Equal(t, "[Old tool result content cleared]", s.Messages[3].Content) }) t.Run("test clear", func(t *testing.T) { backend := filesystem.NewInMemoryBackend() handler := func(ctx context.Context, detail *ToolDetail) (*ClearResult, error) { arguments := make(map[string]string) if err := json.Unmarshal([]byte(detail.ToolArgument.Text), &arguments); err != nil { return nil, err } offloadContent := &OffloadContent{ Arguments: arguments, Result: detail.ToolResult.Parts[0].Text, } replacedArguments := make(map[string]string, len(arguments)) filePath := fmt.Sprintf("/tmp/%s", detail.ToolContext.CallID) for k := range arguments { replacedArguments[k] = "argument offloaded" } return &ClearResult{ ToolArgument: &schema.ToolArgument{Text: toJson(replacedArguments)}, ToolResult: &schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: "result offloaded, retrieve it from " + filePath}, }, }, NeedClear: true, NeedOffload: true, OffloadFilePath: filePath, OffloadContent: toJson(offloadContent), }, nil } config := &Config{ SkipTruncation: true, TokenCounter: 
defaultTokenCounter, MaxTokensForClear: 20, ClearRetentionSuffixLimit: 1, ToolConfig: map[string]*ToolReductionConfig{ "get_weather": { Backend: backend, SkipClear: false, ClearHandler: handler, }, }, } mw, err := New(ctx, config) assert.NoError(t, err) _, s, err := mw.BeforeModelRewriteState(ctx, &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.SystemMessage("you are a helpful assistant"), schema.UserMessage("If it's warmer than 20°C in London, set the thermostat to 20°C, otherwise set it to 18°C."), schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_123456789"), schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_123456789"), }, }, &adk.ModelContext{ Tools: toolsInfo, }) assert.NoError(t, err) assert.Equal(t, []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location":"argument offloaded","unit":"argument offloaded"}`}, }, }, s.Messages[2].ToolCalls) assert.Equal(t, []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[4].ToolCalls) assert.Equal(t, "result offloaded, retrieve it from /tmp/call_987654321", s.Messages[3].Content) fileContent, err := backend.Read(ctx, &filesystem.ReadRequest{ FilePath: "/tmp/call_987654321", }) assert.NoError(t, err) fileContentStr := strings.TrimPrefix(strings.TrimSpace(fileContent.Content), "1\t") oc := &OffloadContent{} err = json.Unmarshal([]byte(fileContentStr), oc) assert.NoError(t, err) assert.Equal(t, &OffloadContent{ Arguments: map[string]string{ 
"location": "London, UK", "unit": "c", }, Result: "Sunny", }, oc) }) t.Run("test skip handled ones", func(t *testing.T) { backend := filesystem.NewInMemoryBackend() config := &Config{ SkipTruncation: true, TokenCounter: defaultTokenCounter, MaxTokensForClear: 20, ClearRetentionSuffixLimit: 0, ToolConfig: map[string]*ToolReductionConfig{ "get_weather": { Backend: backend, SkipClear: false, ClearHandler: defaultClearHandler("/tmp", true, "read_file"), }, }, } mw, err := New(ctx, config) assert.NoError(t, err) msgs := []adk.Message{ schema.SystemMessage("you are a helpful assistant"), schema.UserMessage("If it's warmer than 20°C in London, set the thermostat to 20°C, otherwise set it to 18°C."), schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_123456789"), schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_123456789"), } _, s, err := mw.BeforeModelRewriteState(ctx, &adk.ChatModelAgentState{Messages: msgs}, &adk.ModelContext{Tools: toolsInfo}) assert.NoError(t, err) assert.Equal(t, []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[2].ToolCalls) assert.NotNil(t, msgs[2].Extra[msgReducedFlag]) assert.Equal(t, []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[4].ToolCalls) assert.Equal(t, "Tool result saved to: /tmp/clear/call_987654321\nUse read_file to view", s.Messages[3].Content) fileContent, err := backend.Read(ctx, 
&filesystem.ReadRequest{ FilePath: "/tmp/clear/call_987654321", }) assert.NoError(t, err) fileContentStr := strings.TrimPrefix(strings.TrimSpace(fileContent.Content), "1\t") assert.Equal(t, "Sunny", fileContentStr) msgs = append(msgs, []*schema.Message{ schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_8877665544", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }), schema.ToolMessage("Sunny", "call_8877665544"), }...) _, s, err = mw.BeforeModelRewriteState(ctx, &adk.ChatModelAgentState{Messages: msgs}, &adk.ModelContext{Tools: toolsInfo}) assert.NoError(t, err) assert.Equal(t, []schema.ToolCall{ { ID: "call_987654321", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[2].ToolCalls) assert.NotNil(t, msgs[2].Extra[msgReducedFlag]) assert.Equal(t, []schema.ToolCall{ { ID: "call_123456789", Type: "function", Function: schema.FunctionCall{Name: "get_weather", Arguments: `{"location": "London, UK", "unit": "c"}`}, }, }, s.Messages[4].ToolCalls) assert.NotNil(t, msgs[4].Extra[msgReducedFlag]) assert.Equal(t, "Tool result saved to: /tmp/clear/call_987654321\nUse read_file to view", s.Messages[3].Content) assert.Equal(t, "Tool result saved to: /tmp/clear/call_123456789\nUse read_file to view", s.Messages[5].Content) }) } func TestDefaultOffloadHandler(t *testing.T) { ctx := context.Background() detail := &ToolDetail{ ToolContext: &adk.ToolContext{ Name: "mock_name", CallID: "mock_call_id_12345", }, ToolArgument: &schema.ToolArgument{Text: "anything"}, ToolResult: &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "hello"}}}, } fn := defaultClearHandler("/tmp", true, "read_file") info, err := fn(ctx, detail) assert.NoError(t, err) assert.Equal(t, &ClearResult{ ToolArgument: &schema.ToolArgument{Text: "anything"}, ToolResult: &schema.ToolResult{Parts: 
[]schema.ToolOutputPart{ { Type: schema.ToolPartTypeText, Text: "Tool result saved to: /tmp/clear/mock_call_id_12345\nUse read_file to view", }, }}, NeedClear: true, NeedOffload: true, OffloadFilePath: "/tmp/clear/mock_call_id_12345", OffloadContent: "hello", }, info) } func mockInvokableTool() tool.InvokableTool { type ContentContainer struct { Value string `json:"value"` } s1 := strings.Repeat("hello world", 10) + "\n" s2 := strings.Repeat("hello world", 8) s3 := s1 + s2 t, _ := utils.InferTool("mock_invokable_tool", "test desc", func(ctx context.Context, input *ContentContainer) (output string, err error) { return s3, nil }) return t } func mockStreamableTool() tool.StreamableTool { type ContentContainer struct { Value string `json:"value"` } s1 := strings.Repeat("hello world", 10) + "\n" s2 := strings.Repeat("hello world", 8) s3 := s1 + s2 t, _ := utils.InferStreamTool("mock_streamable_tool", "test desc", func(ctx context.Context, input ContentContainer) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](11) for _, part := range splitStrings(s3, 10) { sw.Send(part, nil) } sw.Close() return sr, nil }) return t } func splitStrings(s string, n int) []string { if n <= 0 { n = 1 } if n == 1 { return []string{s} } if len(s) <= n { parts := make([]string, n) for i := 0; i < len(s); i++ { parts[i] = string(s[i]) } return parts } baseLen := len(s) / n extra := len(s) % n parts := make([]string, 0, n) start := 0 for i := 0; i < n; i++ { end := start + baseLen if i < extra { end++ } parts = append(parts, s[start:end]) start = end } return parts } func toJson(v any) string { b, _ := json.Marshal(v) return string(b) } func TestToolResultFromMessage(t *testing.T) { t.Run("test from content", func(t *testing.T) { msg := schema.ToolMessage("test content", "call_123") result, fromContent, err := toolResultFromMessage(msg) assert.NoError(t, err) assert.True(t, fromContent) assert.NotNil(t, result) assert.Len(t, result.Parts, 1) assert.Equal(t, 
schema.ToolPartTypeText, result.Parts[0].Type) assert.Equal(t, "test content", result.Parts[0].Text) }) t.Run("test from user input multi content", func(t *testing.T) { msg := schema.ToolMessage("", "call_456") msg.UserInputMultiContent = []schema.MessageInputPart{ { Type: schema.ChatMessagePartTypeText, Text: "test text", }, } result, fromContent, err := toolResultFromMessage(msg) assert.NoError(t, err) assert.False(t, fromContent) assert.NotNil(t, result) assert.Len(t, result.Parts, 1) assert.Equal(t, schema.ToolPartTypeText, result.Parts[0].Type) assert.Equal(t, "test text", result.Parts[0].Text) }) t.Run("test invalid role", func(t *testing.T) { msg := schema.UserMessage("test user message") _, _, err := toolResultFromMessage(msg) assert.Error(t, err) assert.Contains(t, err.Error(), "message role") }) } func TestConvMessageInputPartToToolOutputPart(t *testing.T) { t.Run("test text type", func(t *testing.T) { part := schema.MessageInputPart{ Type: schema.ChatMessagePartTypeText, Text: "test text", } result, err := convMessageInputPartToToolOutputPart(part) assert.NoError(t, err) assert.Equal(t, schema.ToolPartTypeText, result.Type) assert.Equal(t, "test text", result.Text) }) t.Run("test image url type", func(t *testing.T) { part := schema.MessageInputPart{ Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{}, } result, err := convMessageInputPartToToolOutputPart(part) assert.NoError(t, err) assert.Equal(t, schema.ToolPartTypeImage, result.Type) assert.NotNil(t, result.Image) }) t.Run("test audio url type", func(t *testing.T) { part := schema.MessageInputPart{ Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageInputAudio{}, } result, err := convMessageInputPartToToolOutputPart(part) assert.NoError(t, err) assert.Equal(t, schema.ToolPartTypeAudio, result.Type) assert.NotNil(t, result.Audio) }) t.Run("test video url type", func(t *testing.T) { part := schema.MessageInputPart{ Type: schema.ChatMessagePartTypeVideoURL, Video: 
&schema.MessageInputVideo{}, } result, err := convMessageInputPartToToolOutputPart(part) assert.NoError(t, err) assert.Equal(t, schema.ToolPartTypeVideo, result.Type) assert.NotNil(t, result.Video) }) t.Run("test file url type", func(t *testing.T) { part := schema.MessageInputPart{ Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{}, } result, err := convMessageInputPartToToolOutputPart(part) assert.NoError(t, err) assert.Equal(t, schema.ToolPartTypeFile, result.Type) assert.NotNil(t, result.File) }) t.Run("test unknown type", func(t *testing.T) { part := schema.MessageInputPart{ Type: "unknown_type", } _, err := convMessageInputPartToToolOutputPart(part) assert.Error(t, err) assert.Contains(t, err.Error(), "unknown msg part type") }) } func TestGetSetMsgOffloadedFlag(t *testing.T) { t.Run("test get offloaded flag - not set", func(t *testing.T) { msg := schema.UserMessage("test") assert.False(t, getMsgOffloadedFlag(msg)) }) t.Run("test get offloaded flag - set", func(t *testing.T) { msg := schema.UserMessage("test") setMsgOffloadedFlag(msg) assert.True(t, getMsgOffloadedFlag(msg)) }) t.Run("test set offloaded flag - nil extra", func(t *testing.T) { msg := schema.UserMessage("test") setMsgOffloadedFlag(msg) assert.True(t, getMsgOffloadedFlag(msg)) }) t.Run("test set offloaded flag - existing extra", func(t *testing.T) { msg := schema.UserMessage("test") msg.Extra = map[string]any{"existing": "value"} setMsgOffloadedFlag(msg) assert.True(t, getMsgOffloadedFlag(msg)) assert.Equal(t, "value", msg.Extra["existing"]) }) } func TestGetSetMsgCachedToken(t *testing.T) { t.Run("test get cached token - not set", func(t *testing.T) { msg := schema.UserMessage("test") tokens, ok := getMsgCachedToken(msg) assert.False(t, ok) assert.Equal(t, int64(0), tokens) }) t.Run("test get cached token - set", func(t *testing.T) { msg := schema.UserMessage("test") setMsgCachedToken(msg, 100) tokens, ok := getMsgCachedToken(msg) assert.True(t, ok) assert.Equal(t, 
int64(100), tokens) }) t.Run("test set cached token - nil extra", func(t *testing.T) { msg := schema.UserMessage("test") setMsgCachedToken(msg, 200) tokens, ok := getMsgCachedToken(msg) assert.True(t, ok) assert.Equal(t, int64(200), tokens) }) t.Run("test set cached token - existing extra", func(t *testing.T) { msg := schema.UserMessage("test") msg.Extra = map[string]any{"existing": "value"} setMsgCachedToken(msg, 300) tokens, ok := getMsgCachedToken(msg) assert.True(t, ok) assert.Equal(t, int64(300), tokens) assert.Equal(t, "value", msg.Extra["existing"]) }) } func TestNewErrors(t *testing.T) { ctx := context.Background() t.Run("test nil config", func(t *testing.T) { _, err := New(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "config must not be nil") }) t.Run("test no backend when not skipping truncation", func(t *testing.T) { config := &Config{ Backend: nil, SkipTruncation: false, } _, err := New(ctx, config) assert.Error(t, err) assert.Contains(t, err.Error(), "backend must be set") }) } func TestGetToolConfig(t *testing.T) { ctx := context.Background() backend := filesystem.NewInMemoryBackend() t.Run("test no tool config", func(t *testing.T) { config := &Config{ Backend: backend, SkipTruncation: true, SkipClear: true, } mw, err := New(ctx, config) assert.NoError(t, err) trmw, ok := mw.(*toolReductionMiddleware) assert.True(t, ok) cfg := trmw.getToolConfig("non_existent_tool", sceneTruncation) assert.NotNil(t, cfg) }) t.Run("test with tool config", func(t *testing.T) { config := &Config{ Backend: backend, SkipTruncation: true, SkipClear: true, ToolConfig: map[string]*ToolReductionConfig{ "test_tool": { SkipTruncation: true, SkipClear: true, }, }, } mw, err := New(ctx, config) assert.NoError(t, err) trmw, ok := mw.(*toolReductionMiddleware) assert.True(t, ok) cfg := trmw.getToolConfig("test_tool", sceneTruncation) assert.NotNil(t, cfg) assert.True(t, cfg.SkipTruncation) }) t.Run("test with tool config needing default handler", func(t *testing.T) 
{ config := &Config{ Backend: backend, SkipTruncation: false, ToolConfig: map[string]*ToolReductionConfig{ "test_tool": { SkipTruncation: false, }, }, } mw, err := New(ctx, config) assert.NoError(t, err) trmw, ok := mw.(*toolReductionMiddleware) assert.True(t, ok) cfg := trmw.getToolConfig("test_tool", sceneTruncation) assert.NotNil(t, cfg) assert.NotNil(t, cfg.TruncHandler) }) } func TestCopyAndFillDefaults(t *testing.T) { t.Run("test empty config", func(t *testing.T) { cfg := &Config{} result, err := cfg.copyAndFillDefaults() assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, "/tmp", result.RootDir) assert.Equal(t, "read_file", result.ReadFileToolName) assert.Equal(t, 50000, result.MaxLengthForTrunc) assert.Equal(t, 1, result.ClearRetentionSuffixLimit) assert.NotNil(t, result.TokenCounter) }) t.Run("test with tool config", func(t *testing.T) { cfg := &Config{ ToolConfig: map[string]*ToolReductionConfig{ "test_tool": { SkipTruncation: true, }, }, } result, err := cfg.copyAndFillDefaults() assert.NoError(t, err) assert.NotNil(t, result.ToolConfig) assert.True(t, result.ToolConfig["test_tool"].SkipTruncation) }) } func TestDefaultTokenCounter(t *testing.T) { ctx := context.Background() t.Run("test with nil messages", func(t *testing.T) { msgs := []*schema.Message{nil} tokens, err := defaultTokenCounter(ctx, msgs, nil) assert.NoError(t, err) assert.GreaterOrEqual(t, tokens, int64(0)) }) t.Run("test with tool info", func(t *testing.T) { toolInfo := &schema.ToolInfo{ Name: "test_tool", Desc: "test description", } tokens, err := defaultTokenCounter(ctx, nil, []*schema.ToolInfo{toolInfo}) assert.NoError(t, err) assert.GreaterOrEqual(t, tokens, int64(0)) }) } func TestDefaultClearHandler(t *testing.T) { ctx := context.Background() t.Run("test empty parts", func(t *testing.T) { handler := defaultClearHandler("/tmp", true, "read_file") detail := &ToolDetail{ ToolContext: &adk.ToolContext{ CallID: "test_call", }, ToolResult: &schema.ToolResult{Parts: 
[]schema.ToolOutputPart{}}, } result, err := handler(ctx, detail) assert.NoError(t, err) assert.False(t, result.NeedClear) }) t.Run("test multimodal content", func(t *testing.T) { handler := defaultClearHandler("/tmp", true, "read_file") detail := &ToolDetail{ ToolContext: &adk.ToolContext{ CallID: "test_call", }, ToolResult: &schema.ToolResult{ Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeImage}}, }, } _, err := handler(ctx, detail) assert.Error(t, err) assert.Contains(t, err.Error(), "not support multimodal") }) t.Run("test no call id", func(t *testing.T) { handler := defaultClearHandler("/tmp", true, "read_file") detail := &ToolDetail{ ToolContext: &adk.ToolContext{}, ToolResult: &schema.ToolResult{ Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "test"}}, }, } result, err := handler(ctx, detail) assert.NoError(t, err) assert.True(t, result.NeedClear) assert.NotEmpty(t, result.OffloadFilePath) }) } ================================================ FILE: adk/middlewares/skill/filesystem_backend.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package skill import ( "context" "fmt" "path/filepath" "strings" "gopkg.in/yaml.v3" "github.com/cloudwego/eino/adk/filesystem" ) const skillFileName = "SKILL.md" type filesystemBackend struct { backend filesystem.Backend baseDir string } // BackendFromFilesystemConfig contains configuration for NewBackendFromFilesystem. type BackendFromFilesystemConfig struct { // Backend is the filesystem.Backend implementation used for file operations. Backend filesystem.Backend // BaseDir is the base directory where skill directories are located. // Each skill should be in a subdirectory containing a SKILL.md file. BaseDir string } // NewBackendFromFilesystem creates a new Backend implementation that reads skills from a filesystem. // It searches for SKILL.md files in immediate subdirectories of the configured BaseDir. // Only first-level subdirectories are scanned; deeply nested SKILL.md files are ignored. func NewBackendFromFilesystem(_ context.Context, config *BackendFromFilesystemConfig) (Backend, error) { if config == nil { return nil, fmt.Errorf("config is required") } if config.Backend == nil { return nil, fmt.Errorf("backend is required") } if config.BaseDir == "" { return nil, fmt.Errorf("baseDir is required") } return &filesystemBackend{ backend: config.Backend, baseDir: config.BaseDir, }, nil } func (b *filesystemBackend) List(ctx context.Context) ([]FrontMatter, error) { skills, err := b.list(ctx) if err != nil { return nil, fmt.Errorf("failed to list skills: %w", err) } matters := make([]FrontMatter, 0, len(skills)) for _, skill := range skills { matters = append(matters, skill.FrontMatter) } return matters, nil } func (b *filesystemBackend) Get(ctx context.Context, name string) (Skill, error) { skills, err := b.list(ctx) if err != nil { return Skill{}, fmt.Errorf("failed to list skills: %w", err) } for _, skill := range skills { if skill.Name == name { return skill, nil } } return Skill{}, fmt.Errorf("skill not found: %s", name) } func (b *filesystemBackend) 
list(ctx context.Context) ([]Skill, error) { var skills []Skill pattern := "*/" + skillFileName entries, err := b.backend.GlobInfo(ctx, &filesystem.GlobInfoRequest{ Pattern: pattern, Path: b.baseDir, }) if err != nil { return nil, fmt.Errorf("failed to glob skill files: %w", err) } for _, entry := range entries { filePath := entry.Path if !filepath.IsAbs(filePath) { filePath = filepath.Join(b.baseDir, filePath) } skill, loadErr := b.loadSkillFromFile(ctx, filePath) if loadErr != nil { return nil, fmt.Errorf("failed to load skill from %s: %w", filePath, loadErr) } skills = append(skills, skill) } return skills, nil } func (b *filesystemBackend) loadSkillFromFile(ctx context.Context, path string) (Skill, error) { fileContent, err := b.backend.Read(ctx, &filesystem.ReadRequest{ FilePath: path, }) if err != nil { return Skill{}, fmt.Errorf("failed to read file: %w", err) } data := stripLineNumbers(fileContent.Content) frontmatter, content, err := parseFrontmatter(data) if err != nil { return Skill{}, fmt.Errorf("failed to parse frontmatter: %w", err) } var fm FrontMatter if err = yaml.Unmarshal([]byte(frontmatter), &fm); err != nil { return Skill{}, fmt.Errorf("failed to unmarshal frontmatter: %w", err) } absDir := filepath.Dir(path) return Skill{ FrontMatter: fm, Content: strings.TrimSpace(content), BaseDirectory: absDir, }, nil } func stripLineNumbers(data string) string { lines := strings.Split(data, "\n") result := make([]string, 0, len(lines)) for _, line := range lines { idx := strings.Index(line, "\t") if idx != -1 { line = line[idx+1:] } result = append(result, line) } return strings.Join(result, "\n") } func parseFrontmatter(data string) (frontmatter string, content string, err error) { const delimiter = "---" data = strings.TrimSpace(data) if !strings.HasPrefix(data, delimiter) { return "", "", fmt.Errorf("file does not start with frontmatter delimiter") } rest := data[len(delimiter):] endIdx := strings.Index(rest, "\n"+delimiter) if endIdx == -1 { return "", 
"", fmt.Errorf("frontmatter closing delimiter not found") } frontmatter = strings.TrimSpace(rest[:endIdx]) content = rest[endIdx+len("\n"+delimiter):] if strings.HasPrefix(content, "\n") { content = content[1:] } return frontmatter, content, nil } ================================================ FILE: adk/middlewares/skill/filesystem_backend_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package skill import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/cloudwego/eino/adk/filesystem" ) func TestNewBackendFromFilesystem(t *testing.T) { ctx := context.Background() t.Run("nil config returns error", func(t *testing.T) { backend, err := NewBackendFromFilesystem(ctx, nil) assert.Nil(t, backend) assert.Error(t, err) assert.Contains(t, err.Error(), "config is required") }) t.Run("nil backend returns error", func(t *testing.T) { backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ BaseDir: "/skills", }) assert.Nil(t, backend) assert.Error(t, err) assert.Contains(t, err.Error(), "backend is required") }) t.Run("empty baseDir returns error", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "", }) assert.Nil(t, backend) assert.Error(t, err) assert.Contains(t, err.Error(), "baseDir is 
required") }) t.Run("valid config succeeds", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) assert.NoError(t, err) assert.NotNil(t, backend) }) } func TestFilesystemBackend_List(t *testing.T) { ctx := context.Background() t.Run("empty directory returns empty list", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/.keep", Content: "", }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skills, err := backend.List(ctx) assert.NoError(t, err) assert.Empty(t, skills) }) t.Run("directory with no SKILL.md files returns empty list", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/subdir/other.txt", Content: "some content", }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skills, err := backend.List(ctx) assert.NoError(t, err) assert.Empty(t, skills) }) t.Run("files in root directory are ignored", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/SKILL.md", Content: `--- name: root-skill description: Root skill --- Content`, }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skills, err := backend.List(ctx) assert.NoError(t, err) assert.Empty(t, skills) }) t.Run("valid skill directory returns skill", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/my-skill/SKILL.md", Content: `--- name: pdf-processing description: 
Extract text and tables from PDF files, fill forms, merge documents. license: Apache-2.0 metadata: author: example-org version: "1.0" --- This is the skill content.`, }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skills, err := backend.List(ctx) assert.NoError(t, err) require.Len(t, skills, 1) assert.Equal(t, "pdf-processing", skills[0].Name) assert.Equal(t, "Extract text and tables from PDF files, fill forms, merge documents.", skills[0].Description) }) t.Run("multiple skill directories returns all skills", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/skill-1/SKILL.md", Content: `--- name: skill-1 description: First skill --- Content 1`, }) _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/skill-2/SKILL.md", Content: `--- name: skill-2 description: Second skill --- Content 2`, }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skills, err := backend.List(ctx) assert.NoError(t, err) assert.Len(t, skills, 2) names := []string{skills[0].Name, skills[1].Name} assert.Contains(t, names, "skill-1") assert.Contains(t, names, "skill-2") }) t.Run("invalid SKILL.md returns error", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/invalid-skill/SKILL.md", Content: `No frontmatter here`, }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skills, err := backend.List(ctx) assert.Error(t, err) assert.Nil(t, skills) assert.Contains(t, err.Error(), "failed to load skill") }) t.Run("non-existent baseDir returns empty list", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() backend, err := 
NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/path/that/does/not/exist", }) require.NoError(t, err) skills, err := backend.List(ctx) assert.NoError(t, err) assert.Empty(t, skills) }) t.Run("deeply nested SKILL.md is ignored", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/valid-skill/SKILL.md", Content: `--- name: valid-skill description: Valid skill --- Content`, }) _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/deep/nested/SKILL.md", Content: `--- name: nested-skill description: Nested skill --- Content`, }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skills, err := backend.List(ctx) assert.NoError(t, err) assert.Len(t, skills, 1) assert.Equal(t, "valid-skill", skills[0].Name) }) } func TestFilesystemBackend_Get(t *testing.T) { ctx := context.Background() t.Run("skill not found returns error", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/.keep", Content: "", }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skill, err := backend.Get(ctx, "non-existent") assert.Error(t, err) assert.Empty(t, skill) assert.Contains(t, err.Error(), "skill not found") }) t.Run("existing skill is returned", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/test-skill/SKILL.md", Content: `--- name: test-skill description: Test skill description --- Test content here.`, }) backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skill, err := backend.Get(ctx, "test-skill") 
assert.NoError(t, err) assert.Equal(t, "test-skill", skill.Name) assert.Equal(t, "Test skill description", skill.Description) assert.Equal(t, "Test content here.", skill.Content) }) t.Run("get specific skill from multiple", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() for _, name := range []string{"alpha", "beta", "gamma"} { _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/" + name + "/SKILL.md", Content: `--- name: ` + name + ` description: Skill ` + name + ` --- Content for ` + name, }) } backend, err := NewBackendFromFilesystem(ctx, &BackendFromFilesystemConfig{ Backend: fsBackend, BaseDir: "/skills", }) require.NoError(t, err) skill, err := backend.Get(ctx, "beta") assert.NoError(t, err) assert.Equal(t, "beta", skill.Name) assert.Equal(t, "Skill beta", skill.Description) assert.Equal(t, "Content for beta", skill.Content) }) } func TestParseFrontmatter(t *testing.T) { t.Run("valid frontmatter", func(t *testing.T) { data := `--- name: test description: test description --- This is the content.` fm, content, err := parseFrontmatter(data) assert.NoError(t, err) assert.Equal(t, "name: test\ndescription: test description", fm) assert.Equal(t, "This is the content.", content) }) t.Run("frontmatter with multiline content", func(t *testing.T) { data := `--- name: test --- Line 1 Line 2 Line 3` fm, content, err := parseFrontmatter(data) assert.NoError(t, err) assert.Equal(t, "name: test", fm) assert.Equal(t, "Line 1\nLine 2\nLine 3", content) }) t.Run("frontmatter with leading/trailing whitespace", func(t *testing.T) { data := ` --- name: test --- Content ` fm, content, err := parseFrontmatter(data) assert.NoError(t, err) assert.Equal(t, "name: test", fm) assert.Equal(t, "Content", content) }) t.Run("missing opening delimiter returns error", func(t *testing.T) { data := `name: test --- Content` fm, content, err := parseFrontmatter(data) assert.Error(t, err) assert.Empty(t, fm) assert.Empty(t, content) assert.Contains(t, 
err.Error(), "does not start with frontmatter delimiter") }) t.Run("missing closing delimiter returns error", func(t *testing.T) { data := `--- name: test Content without closing` fm, content, err := parseFrontmatter(data) assert.Error(t, err) assert.Empty(t, fm) assert.Empty(t, content) assert.Contains(t, err.Error(), "closing delimiter not found") }) t.Run("empty frontmatter", func(t *testing.T) { data := `--- --- Content only` fm, content, err := parseFrontmatter(data) assert.NoError(t, err) assert.Empty(t, fm) assert.Equal(t, "Content only", content) }) t.Run("empty content", func(t *testing.T) { data := `--- name: test ---` fm, content, err := parseFrontmatter(data) assert.NoError(t, err) assert.Equal(t, "name: test", fm) assert.Empty(t, content) }) t.Run("content with --- inside", func(t *testing.T) { data := `--- name: test --- Content with --- in the middle` fm, content, err := parseFrontmatter(data) assert.NoError(t, err) assert.Equal(t, "name: test", fm) assert.Equal(t, "Content with --- in the middle", content) }) } func TestLoadSkillFromFile(t *testing.T) { ctx := context.Background() t.Run("valid skill file", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/SKILL.md", Content: `--- name: file-skill description: Skill from file --- File skill content.`, }) backend := &filesystemBackend{backend: fsBackend, baseDir: "/skills"} skill, err := backend.loadSkillFromFile(ctx, "/skills/SKILL.md") assert.NoError(t, err) assert.Equal(t, "file-skill", skill.Name) assert.Equal(t, "Skill from file", skill.Description) assert.Equal(t, "File skill content.", skill.Content) assert.Equal(t, "/skills", skill.BaseDirectory) }) t.Run("non-existent file returns error", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() backend := &filesystemBackend{backend: fsBackend, baseDir: "/tmp"} skill, err := backend.loadSkillFromFile(ctx, "/path/to/nonexistent/SKILL.md") 
assert.Error(t, err) assert.Empty(t, skill) assert.Contains(t, err.Error(), "failed to read file") }) t.Run("invalid yaml in frontmatter returns error", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/SKILL.md", Content: `--- name: [invalid yaml --- Content`, }) backend := &filesystemBackend{backend: fsBackend, baseDir: "/skills"} skill, err := backend.loadSkillFromFile(ctx, "/skills/SKILL.md") assert.Error(t, err) assert.Empty(t, skill) assert.Contains(t, err.Error(), "failed to unmarshal frontmatter") }) t.Run("content with extra whitespace is trimmed", func(t *testing.T) { fsBackend := filesystem.NewInMemoryBackend() _ = fsBackend.Write(ctx, &filesystem.WriteRequest{ FilePath: "/skills/SKILL.md", Content: `--- name: trimmed-skill description: desc --- Content with whitespace `, }) backend := &filesystemBackend{backend: fsBackend, baseDir: "/skills"} skill, err := backend.loadSkillFromFile(ctx, "/skills/SKILL.md") assert.NoError(t, err) assert.Equal(t, "Content with whitespace", skill.Content) }) } ================================================ FILE: adk/middlewares/skill/prompt.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */

package skill

// Prompt text used by the skill middleware. Most constants come in
// English/Chinese pairs (the Chinese variant carries the "Chinese" suffix);
// the skill tool code chooses between them at runtime.
const (
	// systemPrompt is appended to the agent instruction. '{tool_name}' is a
	// pyfmt placeholder replaced with the skill tool name by buildSystemPrompt.
	systemPrompt = ` # Skills System **How to Use Skills (Progressive Disclosure):** Skills follow a **progressive disclosure** pattern - you see their name and description above, but only read full instructions when needed: 1. **Recognize when a skill applies**: Check if the user's task matches a skill's description 2. **Read the skill's full instructions**: Use the '{tool_name}' tool to load skill 3. **Follow the skill's instructions**: tool result contains step-by-step workflows, best practices, and examples 4. **Access supporting files**: Skills may include helper scripts, configs, or reference docs - use absolute paths **When to Use Skills:** - User's request matches a skill's domain (e.g., "research X" -> web-research skill) - You need specialized knowledge or structured workflows - A skill provides proven patterns for complex tasks **Executing Skill Scripts:** Skills may contain Python scripts or other executable files. Always use absolute paths. **Example Workflow:** User: "Can you research the latest developments in quantum computing?" 1. Check available skills -> See "web-research" skill 2. Call '{tool_name}' tool to read the full skill instructions 3. Follow the skill's research workflow (search -> organize -> synthesize) 4. Use any helper scripts with absolute paths Remember: Skills make you more capable and consistent. When in doubt, check if a skill exists for the task! `
	// systemPromptChinese is the Chinese counterpart of systemPrompt (same
	// '{tool_name}' placeholder).
	systemPromptChinese = ` # Skill 系统 **如何使用 Skill(技能)(渐进式展示):** Skill 遵循**渐进式展示**模式 - 你可以在上方看到 Skill 的名称和描述,但只在需要时才阅读完整说明: 1. **识别 Skill 适用场景**:检查用户的任务是否匹配某个 Skill 的描述 2. **阅读 Skill 的完整说明**:使用 '{tool_name}' 工具加载 Skill 3. **遵循 Skill 说明操作**:工具结果包含逐步工作流程、最佳实践和示例 4. **访问支持文件**:Skill 可能包含辅助脚本、配置或参考文档 - 使用绝对路径访问 **何时使用 Skill:** - 用户请求匹配某个 Skill 的领域(例如"研究 X" -> web-research Skill) - 你需要专业知识或结构化工作流程 - 某个 Skill 为复杂任务提供了经过验证的模式 **执行 Skill 脚本:** Skill 可能包含 Python 脚本或其他可执行文件。始终使用绝对路径。 **示例工作流程:** 用户:"你能研究一下量子计算的最新发展吗?" 1. 检查可用 Skill -> 发现 "web-research" Skill 2.
调用 '{tool_name}' 工具读取完整的 Skill 说明 3. 遵循 Skill 的研究工作流程(搜索 -> 整理 -> 综合) 4. 使用绝对路径运行任何辅助脚本 记住:Skill 让你更加强大和稳定。如有疑问,请检查是否存在适用于该任务的 Skill! `
	// toolDescriptionBase is the fixed head of the skill tool description;
	// the rendered toolDescriptionTemplate (the list of skills) is appended
	// after it. The ` + "`" + ` pieces splice literal backticks into the raw string.
	toolDescriptionBase = `Execute a skill within the main conversation When users ask you to perform tasks, check if any of the available skills below can help complete the task more effectively. Skills provide specialized capabilities and domain knowledge. How to invoke: - Use the exact string inside tag as the skill name (no arguments) - Examples: - ` + "`" + `skill: "pdf"` + "`" + ` - invoke the pdf skill - ` + "`" + `skill: "xlsx"` + "`" + ` - invoke the xlsx skill - ` + "`" + `skill: "ms-office-suite:pdf"` + "`" + ` - invoke using fully qualified name Important: - When a skill is relevant, you must invoke this tool IMMEDIATELY as your first action - NEVER just announce or mention a skill in your text response without actually calling this tool - This is a BLOCKING REQUIREMENT: invoke the relevant Skill tool BEFORE generating any other response about the task - Only use skills listed in below - Do not invoke a skill that is already running - Skill content may contain relative paths.
Convert them to absolute paths using the base directory provided in the tool result `
	// toolDescriptionBaseChinese is the Chinese counterpart of toolDescriptionBase.
	toolDescriptionBaseChinese = `在主对话中执行 Skill(技能) 当用户要求你执行任务时,检查下方可用 Skill 列表中是否有 Skill 可以更有效地完成任务。Skill 提供专业能力和领域知识。 如何调用: - 使用 标签内的完整字符串作为 Skill 名称(无需其他参数) - 示例: - ` + "`" + `skill: "pdf"` + "`" + ` - 调用 pdf Skill - ` + "`" + `skill: "xlsx"` + "`" + ` - 调用 xlsx Skill - ` + "`" + `skill: "ms-office-suite:pdf"` + "`" + ` - 使用完全限定名称调用 重要说明: - 当 Skill 相关时,你必须立即调用此工具作为第一个动作 - 切勿仅在文本回复中提及 Skill 而不实际调用此工具 - 这是阻塞性要求:在生成任何关于任务的其他响应之前,先调用相关的 Skill 工具 - 仅使用 中列出的 Skill - 不要调用已经运行中的 Skill - Skill 内容中可能包含相对路径,需使用工具返回的 base directory 将其转换为绝对路径 `
	// toolDescriptionTemplate is a text/template executed with a value whose
	// .Matters field is the list of skill front matters; it emits each
	// skill's Name and Description. The "{{-" markers trim preceding whitespace.
	toolDescriptionTemplate = ` {{- range .Matters }} {{ .Name }} {{ .Description }} {{- end }} `
	// toolResult (printf: skill name) is the header of an inline skill result.
	toolResult = "Launching skill: %s\n"
	// toolResultChinese is the Chinese counterpart of toolResult.
	toolResultChinese = "正在启动 Skill:%s\n"
	// userContent (printf: base directory, skill content) follows toolResult
	// in the inline result and is also the message sent to a forked agent.
	userContent = `Base directory for this skill: %s %s`
	// userContentChinese is the Chinese counterpart of userContent.
	userContentChinese = `此 Skill 的目录:%s %s`
	// toolName is the default skill tool name, used when Config.SkillToolName is nil.
	toolName = "skill"
	// subAgentResultFormat (printf: skill name, joined agent output) wraps
	// the result of a fork-mode execution.
	subAgentResultFormat = "Skill \"%s\" completed (sub-agent execution).\n\nResult:\n%s"
	// subAgentResultFormatChinese is the Chinese counterpart of subAgentResultFormat.
	subAgentResultFormatChinese = "Skill \"%s\" 已完成(子 Agent 执行)。\n\n结果:\n%s"
)

================================================
FILE: adk/middlewares/skill/skill.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Package skill provides the skill middleware, types, and a local filesystem backend.
package skill import ( "bytes" "context" "encoding/json" "fmt" "strings" "text/template" "github.com/slongfield/pyfmt" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) type ContextMode string const ( // ContextModeForkWithContext forks a new agent to run the skill, // carrying over the original message history from the parent agent. ContextModeForkWithContext ContextMode = "fork_with_context" // ContextModeFork forks a new agent to run the skill // with a clean context, discarding the original message history. ContextModeFork ContextMode = "fork" ) type FrontMatter struct { Name string `yaml:"name"` Description string `yaml:"description"` Context ContextMode `yaml:"context"` Agent string `yaml:"agent"` Model string `yaml:"model"` } type Skill struct { FrontMatter Content string BaseDirectory string } type Backend interface { List(ctx context.Context) ([]FrontMatter, error) Get(ctx context.Context, name string) (Skill, error) } // AgentHubOptions contains options passed to AgentHub.Get when creating an agent for skill execution. type AgentHubOptions struct { // Model is the resolved model instance when a skill specifies a "model" field in frontmatter. // nil means the skill did not specify a model override; implementations should use their default. Model model.ToolCallingChatModel } // AgentHub provides agent instances for context mode (fork/fork_with_context) execution. type AgentHub interface { // Get returns an Agent by name. When name is empty, implementations should return a default agent. // The opts parameter carries skill-level overrides (e.g., model) resolved by the framework. Get(ctx context.Context, name string, opts *AgentHubOptions) (adk.Agent, error) } // ModelHub resolves model instances by name for skills that specify a "model" field in frontmatter. 
type ModelHub interface { Get(ctx context.Context, name string) (model.ToolCallingChatModel, error) } // SystemPromptFunc is a function that returns a custom system prompt. // The toolName parameter is the name of the skill tool (default: "skill"). type SystemPromptFunc func(ctx context.Context, toolName string) string // ToolDescriptionFunc is a function that returns a custom tool description. // The skills parameter contains all available skill front matters. type ToolDescriptionFunc func(ctx context.Context, skills []FrontMatter) string // Config is the configuration for the skill middleware. type Config struct { // Backend is the backend for retrieving skills. Backend Backend // SkillToolName is the custom name for the skill tool. If nil, the default name "skill" is used. SkillToolName *string // Deprecated: Use adk.SetLanguage(adk.LanguageChinese) instead to enable Chinese prompts globally. // This field will be removed in a future version. UseChinese bool // AgentHub provides agent factories for context mode (fork/isolate) execution. // Required when skills use "context: fork" or "context: isolate" in frontmatter. // The agent factory is retrieved by agent name (skill.Agent) from this hub. // When skill.Agent is empty, AgentHub.Get is called with an empty string, // allowing the hub implementation to return a default agent. AgentHub AgentHub // ModelHub provides model instances for skills that specify a "model" field in frontmatter. // Used in two scenarios: // - With context mode (fork/isolate): The model is passed to the AgentFactory // - Without context mode (inline): The model becomes active for subsequent ChatModel requests // If nil, skills with model specification will be ignored in inline mode, // or return an error in context mode. ModelHub ModelHub // CustomSystemPrompt allows customizing the system prompt injected into the agent. // If nil, the default system prompt is used. // The function receives the skill tool name as a parameter. 
CustomSystemPrompt SystemPromptFunc // CustomToolDescription allows customizing the tool description for the skill tool. // If nil, the default tool description is used. // The function receives all available skill front matters as a parameter. CustomToolDescription ToolDescriptionFunc } // NewMiddleware creates a new skill middleware handler for ChatModelAgent. // // The handler provides a skill tool that allows agents to load and execute skills // defined in SKILL.md files. Skills can run in different modes based on their // frontmatter configuration: // // - Inline mode (default): Skill content is returned directly as tool result // - Fork mode (context: fork): Forks a new agent with a clean context, discarding message history // - Fork with context mode (context: fork_with_context): Forks a new agent carrying over message history // // Example usage: // // handler, err := skill.NewMiddleware(ctx, &skill.Config{ // Backend: backend, // AgentHub: myAgentHub, // ModelHub: myModelHub, // }) // if err != nil { // return err // } // // agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ // // ... 
// Middlewares: []adk.ChatModelAgentMiddleware{handler}, // }) func NewMiddleware(ctx context.Context, config *Config) (adk.ChatModelAgentMiddleware, error) { if config == nil { return nil, fmt.Errorf("config is required") } if config.Backend == nil { return nil, fmt.Errorf("backend is required") } name := toolName if config.SkillToolName != nil { name = *config.SkillToolName } var instruction string if config.CustomSystemPrompt != nil { instruction = config.CustomSystemPrompt(ctx, name) } else { var err error instruction, err = buildSystemPrompt(name, config.UseChinese) if err != nil { return nil, err } } return &skillHandler{ instruction: instruction, tool: &skillTool{ b: config.Backend, toolName: name, useChinese: config.UseChinese, agentHub: config.AgentHub, modelHub: config.ModelHub, customToolDescription: config.CustomToolDescription, }, }, nil } type skillHandler struct { *adk.BaseChatModelAgentMiddleware instruction string tool *skillTool } func (h *skillHandler) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { runCtx.Instruction = runCtx.Instruction + "\n" + h.instruction runCtx.Tools = append(runCtx.Tools, h.tool) return ctx, runCtx, nil } func (h *skillHandler) WrapModel(ctx context.Context, m model.BaseChatModel, mc *adk.ModelContext) (model.BaseChatModel, error) { if h.tool.modelHub == nil { return m, nil } modelName, found, err := adk.GetRunLocalValue(ctx, activeModelKey) if err != nil { return nil, fmt.Errorf("failed to get active model from run local value: %w", err) } if !found { return m, nil } name, ok := modelName.(string) if !ok || name == "" { return m, nil } newModel, err := h.tool.modelHub.Get(ctx, name) if err != nil { return nil, fmt.Errorf("failed to get model '%s' from ModelHub: %w", name, err) } return newModel, nil } const activeModelKey = "__skill_active_model__" // New creates a new skill middleware. // It provides a tool for the agent to use skills. 
// // Deprecated: Use NewChatModelAgentMiddleware instead. New does not support fork mode execution // because AgentMiddleware cannot save message history for fork mode. func New(ctx context.Context, config *Config) (adk.AgentMiddleware, error) { if config == nil { return adk.AgentMiddleware{}, fmt.Errorf("config is required") } if config.Backend == nil { return adk.AgentMiddleware{}, fmt.Errorf("backend is required") } name := toolName if config.SkillToolName != nil { name = *config.SkillToolName } var sp string if config.CustomSystemPrompt != nil { sp = config.CustomSystemPrompt(ctx, name) } else { var err error sp, err = buildSystemPrompt(name, config.UseChinese) if err != nil { return adk.AgentMiddleware{}, err } } return adk.AgentMiddleware{ AdditionalInstruction: sp, AdditionalTools: []tool.BaseTool{&skillTool{ b: config.Backend, toolName: name, useChinese: config.UseChinese, customToolDescription: config.CustomToolDescription, }}, }, nil } func buildSystemPrompt(skillToolName string, useChinese bool) (string, error) { var prompt string if useChinese { prompt = systemPromptChinese } else { prompt = internal.SelectPrompt(internal.I18nPrompts{ English: systemPrompt, Chinese: systemPromptChinese, }) } return pyfmt.Fmt(prompt, map[string]string{ "tool_name": skillToolName, }) } type skillTool struct { b Backend toolName string useChinese bool agentHub AgentHub modelHub ModelHub customToolDescription ToolDescriptionFunc } type descriptionTemplateHelper struct { Matters []FrontMatter } func (s *skillTool) Info(ctx context.Context) (*schema.ToolInfo, error) { skills, err := s.b.List(ctx) if err != nil { return nil, fmt.Errorf("failed to list skills: %w", err) } var fullDesc string if s.customToolDescription != nil { fullDesc = s.customToolDescription(ctx, skills) } else { desc, err := renderToolDescription(skills) if err != nil { return nil, fmt.Errorf("failed to render skill tool description: %w", err) } descBase := internal.SelectPrompt(internal.I18nPrompts{ 
English: toolDescriptionBase, Chinese: toolDescriptionBaseChinese, }) fullDesc = descBase + desc } paramDesc := internal.SelectPrompt(internal.I18nPrompts{ English: "The skill name (no arguments). E.g., \"pdf\" or \"xlsx\"", Chinese: "Skill 名称(无需其他参数)。例如:\"pdf\" 或 \"xlsx\"", }) return &schema.ToolInfo{ Name: s.toolName, Desc: fullDesc, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "skill": { Type: schema.String, Desc: paramDesc, Required: true, }, }), }, nil } type inputArguments struct { Skill string `json:"skill"` } func (s *skillTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { args := &inputArguments{} err := json.Unmarshal([]byte(argumentsInJSON), args) if err != nil { return "", fmt.Errorf("failed to unmarshal arguments: %w", err) } skill, err := s.b.Get(ctx, args.Skill) if err != nil { return "", fmt.Errorf("failed to get skill: %w", err) } switch skill.Context { case ContextModeForkWithContext: return s.runAgentMode(ctx, skill, true) case ContextModeFork: return s.runAgentMode(ctx, skill, false) default: if skill.Model != "" { s.setActiveModel(ctx, skill.Model) } return s.buildSkillResult(skill) } } func (s *skillTool) setActiveModel(ctx context.Context, modelName string) { _ = adk.SetRunLocalValue(ctx, activeModelKey, modelName) } func (s *skillTool) buildSkillResult(skill Skill) (string, error) { resultFmt := internal.SelectPrompt(internal.I18nPrompts{ English: toolResult, Chinese: toolResultChinese, }) contentFmt := internal.SelectPrompt(internal.I18nPrompts{ English: userContent, Chinese: userContentChinese, }) return fmt.Sprintf(resultFmt, skill.Name) + fmt.Sprintf(contentFmt, skill.BaseDirectory, skill.Content), nil } func (s *skillTool) runAgentMode(ctx context.Context, skill Skill, forkHistory bool) (string, error) { if s.agentHub == nil { return "", fmt.Errorf("skill '%s' requires context:%s but AgentHub is not configured", skill.Name, skill.Context) } opts := 
&AgentHubOptions{} if skill.Model != "" { if s.modelHub == nil { return "", fmt.Errorf("skill '%s' requires model '%s' but ModelHub is not configured", skill.Name, skill.Model) } m, err := s.modelHub.Get(ctx, skill.Model) if err != nil { return "", fmt.Errorf("failed to get model '%s' from ModelHub: %w", skill.Model, err) } opts.Model = m } agent, err := s.agentHub.Get(ctx, skill.Agent, opts) if err != nil { return "", fmt.Errorf("failed to get agent '%s' from AgentHub: %w", skill.Agent, err) } var messages []adk.Message skillContent, err := s.buildSkillResult(skill) if err != nil { return "", fmt.Errorf("failed to build skill result: %w", err) } if forkHistory { messages, err = s.getMessagesFromState(ctx) if err != nil { return "", fmt.Errorf("failed to get messages from state: %w", err) } toolCallID := compose.GetToolCallID(ctx) messages = append(messages, schema.ToolMessage(skillContent, toolCallID)) } else { messages = []adk.Message{schema.UserMessage(skillContent)} } input := &adk.AgentInput{ Messages: messages, EnableStreaming: false, } iter := agent.Run(ctx, input) var results []string for { event, ok := iter.Next() if !ok { break } if event.Err != nil { return "", fmt.Errorf("failed to run agent event: %w", event.Err) } if event.Output == nil || event.Output.MessageOutput == nil { continue } msg, msgErr := event.Output.MessageOutput.GetMessage() if msgErr != nil { return "", fmt.Errorf("failed to get message from event: %w", msgErr) } if msg != nil && msg.Content != "" { results = append(results, msg.Content) } } resultFmt := internal.SelectPrompt(internal.I18nPrompts{ English: subAgentResultFormat, Chinese: subAgentResultFormatChinese, }) return fmt.Sprintf(resultFmt, skill.Name, strings.Join(results, "\n")), nil } func (s *skillTool) getMessagesFromState(ctx context.Context) ([]adk.Message, error) { var messages []adk.Message err := compose.ProcessState(ctx, func(_ context.Context, st *adk.State) error { messages = make([]adk.Message, len(st.Messages)) 
copy(messages, st.Messages) return nil }) if err != nil { return nil, fmt.Errorf("failed to process state: %w", err) } return messages, nil } func renderToolDescription(matters []FrontMatter) (string, error) { tpl, err := template.New("skills").Parse(toolDescriptionTemplate) if err != nil { return "", err } var buf bytes.Buffer err = tpl.Execute(&buf, descriptionTemplateHelper{Matters: matters}) if err != nil { return "", err } return buf.String(), nil } ================================================ FILE: adk/middlewares/skill/skill_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package skill import ( "context" "errors" "fmt" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) type inMemoryBackend struct { m []Skill } func (i *inMemoryBackend) List(ctx context.Context) ([]FrontMatter, error) { matters := make([]FrontMatter, 0, len(i.m)) for _, skill := range i.m { matters = append(matters, skill.FrontMatter) } return matters, nil } func (i *inMemoryBackend) Get(ctx context.Context, name string) (Skill, error) { for _, skill := range i.m { if skill.Name == name { return skill, nil } } return Skill{}, errors.New("skill not found") } func TestTool(t *testing.T) { backend := &inMemoryBackend{m: []Skill{ { FrontMatter: FrontMatter{ Name: "name1", Description: "desc1", }, Content: "content1", BaseDirectory: "basedir1", }, { FrontMatter: FrontMatter{ Name: "name2", Description: "desc2", }, Content: "content2", BaseDirectory: "basedir2", }, }} ctx := context.Background() m, err := New(ctx, &Config{Backend: backend}) assert.NoError(t, err) assert.Len(t, m.AdditionalTools, 1) to := m.AdditionalTools[0].(tool.InvokableTool) info, err := to.Info(ctx) assert.NoError(t, err) assert.Equal(t, "skill", info.Name) desc := strings.TrimPrefix(info.Desc, toolDescriptionBase) assert.Equal(t, ` name1 desc1 name2 desc2 `, desc) result, err := to.InvokableRun(ctx, `{"skill": "name1"}`) assert.NoError(t, err) assert.Equal(t, `Launching skill: name1 Base directory for this skill: basedir1 content1`, result) // chinese internal.SetLanguage(internal.LanguageChinese) defer internal.SetLanguage(internal.LanguageEnglish) m, err = New(ctx, &Config{Backend: backend}) assert.NoError(t, err) assert.Len(t, m.AdditionalTools, 1) to = m.AdditionalTools[0].(tool.InvokableTool) info, err = to.Info(ctx) assert.NoError(t, err) 
assert.Equal(t, "skill", info.Name) desc = strings.TrimPrefix(info.Desc, toolDescriptionBaseChinese) assert.Equal(t, ` name1 desc1 name2 desc2 `, desc) result, err = to.InvokableRun(ctx, `{"skill": "name1"}`) assert.NoError(t, err) assert.Equal(t, `正在启动 Skill:name1 此 Skill 的目录:basedir1 content1`, result) } func TestSkillToolName(t *testing.T) { ctx := context.Background() // default m, err := New(ctx, &Config{Backend: &inMemoryBackend{m: []Skill{}}}) assert.NoError(t, err) // instruction assert.Contains(t, m.AdditionalInstruction, "'skill'") // tool name info, err := m.AdditionalTools[0].Info(ctx) assert.NoError(t, err) assert.Equal(t, "skill", info.Name) // customized name := "load_skill" m, err = New(ctx, &Config{Backend: &inMemoryBackend{m: []Skill{}}, SkillToolName: &name}) assert.NoError(t, err) assert.Contains(t, m.AdditionalInstruction, "'load_skill'") info, err = m.AdditionalTools[0].Info(ctx) assert.NoError(t, err) assert.Equal(t, "load_skill", info.Name) } // --- Mock types for NewMiddleware tests --- type mockModel struct { model.ToolCallingChatModel name string } type mockModelHub struct { models map[string]model.ToolCallingChatModel } func (h *mockModelHub) Get(_ context.Context, name string) (model.ToolCallingChatModel, error) { m, ok := h.models[name] if !ok { return nil, fmt.Errorf("model not found: %s", name) } return m, nil } type mockAgent struct { events []*adk.AgentEvent } func (a *mockAgent) Name(_ context.Context) string { return "mock-agent" } func (a *mockAgent) Description(_ context.Context) string { return "mock agent for testing" } func (a *mockAgent) Run(_ context.Context, _ *adk.AgentInput, _ ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iter, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]() go func() { defer gen.Close() for _, e := range a.events { gen.Send(e) } }() return iter } type mockAgentHub struct { agents map[string]adk.Agent lastOpts *AgentHubOptions defaultAgent adk.Agent } func (h *mockAgentHub) Get(_ 
context.Context, name string, opts *AgentHubOptions) (adk.Agent, error) { h.lastOpts = opts if name == "" && h.defaultAgent != nil { return h.defaultAgent, nil } a, ok := h.agents[name] if !ok { return nil, fmt.Errorf("agent not found: %s", name) } return a, nil } type errorBackend struct { listErr error getErr error } func (b *errorBackend) List(_ context.Context) ([]FrontMatter, error) { return nil, b.listErr } func (b *errorBackend) Get(_ context.Context, _ string) (Skill, error) { return Skill{}, b.getErr } // --- NewMiddleware tests --- func TestNewMiddleware(t *testing.T) { ctx := context.Background() t.Run("nil config returns error", func(t *testing.T) { handler, err := NewMiddleware(ctx, nil) assert.Nil(t, handler) assert.Error(t, err) assert.Contains(t, err.Error(), "config is required") }) t.Run("nil backend returns error", func(t *testing.T) { handler, err := NewMiddleware(ctx, &Config{}) assert.Nil(t, handler) assert.Error(t, err) assert.Contains(t, err.Error(), "backend is required") }) t.Run("valid config succeeds", func(t *testing.T) { backend := &inMemoryBackend{m: []Skill{}} handler, err := NewMiddleware(ctx, &Config{Backend: backend}) assert.NoError(t, err) assert.NotNil(t, handler) }) t.Run("custom tool name", func(t *testing.T) { backend := &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Description: "d1"}, Content: "c1"}, }} name := "load_skill" handler, err := NewMiddleware(ctx, &Config{Backend: backend, SkillToolName: &name}) require.NoError(t, err) h := handler.(*skillHandler) assert.Contains(t, h.instruction, "'load_skill'") assert.Equal(t, "load_skill", h.tool.toolName) }) t.Run("custom system prompt", func(t *testing.T) { backend := &inMemoryBackend{m: []Skill{}} handler, err := NewMiddleware(ctx, &Config{ Backend: backend, CustomSystemPrompt: func(_ context.Context, toolName string) string { return "custom prompt for " + toolName }, }) require.NoError(t, err) h := handler.(*skillHandler) assert.Equal(t, "custom prompt 
for skill", h.instruction) }) t.Run("custom tool description", func(t *testing.T) { backend := &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Description: "d1"}, Content: "c1"}, }} handler, err := NewMiddleware(ctx, &Config{ Backend: backend, CustomToolDescription: func(_ context.Context, skills []FrontMatter) string { return fmt.Sprintf("custom desc with %d skills", len(skills)) }, }) require.NoError(t, err) h := handler.(*skillHandler) info, err := h.tool.Info(ctx) require.NoError(t, err) assert.Equal(t, "custom desc with 1 skills", info.Desc) }) } func TestBeforeAgent(t *testing.T) { ctx := context.Background() backend := &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Description: "d1"}, Content: "c1"}, }} handler, err := NewMiddleware(ctx, &Config{Backend: backend}) require.NoError(t, err) runCtx := &adk.ChatModelAgentContext{ Instruction: "base instruction", Tools: []tool.BaseTool{}, } _, newRunCtx, err := handler.BeforeAgent(ctx, runCtx) assert.NoError(t, err) assert.Contains(t, newRunCtx.Instruction, "base instruction") assert.Contains(t, newRunCtx.Instruction, "Skills System") assert.Len(t, newRunCtx.Tools, 1) // verify the added tool is the skill tool info, err := newRunCtx.Tools[0].Info(ctx) assert.NoError(t, err) assert.Equal(t, "skill", info.Name) } func TestSkillToolInfo(t *testing.T) { ctx := context.Background() t.Run("list error propagates", func(t *testing.T) { st := &skillTool{ b: &errorBackend{listErr: errors.New("list failed")}, toolName: "skill", } info, err := st.Info(ctx) assert.Nil(t, info) assert.Error(t, err) assert.Contains(t, err.Error(), "list failed") }) t.Run("description contains all skills", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "alpha", Description: "desc-alpha"}}, {FrontMatter: FrontMatter{Name: "beta", Description: "desc-beta"}}, }}, toolName: "skill", } info, err := st.Info(ctx) require.NoError(t, err) assert.Contains(t, 
info.Desc, "alpha") assert.Contains(t, info.Desc, "desc-alpha") assert.Contains(t, info.Desc, "beta") assert.Contains(t, info.Desc, "desc-beta") }) } func TestInvokableRun_InlineMode(t *testing.T) { ctx := context.Background() t.Run("invalid json returns error", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{}}, toolName: "skill", } _, err := st.InvokableRun(ctx, "not json") assert.Error(t, err) assert.Contains(t, err.Error(), "failed to unmarshal") }) t.Run("skill not found returns error", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{}}, toolName: "skill", } _, err := st.InvokableRun(ctx, `{"skill": "nonexistent"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to get skill") }) t.Run("inline mode returns skill content", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{ { FrontMatter: FrontMatter{Name: "pdf", Description: "PDF processing"}, Content: "Process PDF files here", BaseDirectory: "/skills/pdf", }, }}, toolName: "skill", } result, err := st.InvokableRun(ctx, `{"skill": "pdf"}`) assert.NoError(t, err) assert.Contains(t, result, "pdf") assert.Contains(t, result, "/skills/pdf") assert.Contains(t, result, "Process PDF files here") }) } func TestInvokableRun_AgentMode(t *testing.T) { ctx := context.Background() t.Run("fork mode without AgentHub returns error", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Context: ContextModeFork}, Content: "c1"}, }}, toolName: "skill", } _, err := st.InvokableRun(ctx, `{"skill": "s1"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "AgentHub is not configured") }) t.Run("fork_with_context mode without AgentHub returns error", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Context: ContextModeForkWithContext}, Content: "c1"}, }}, toolName: "skill", } _, err := st.InvokableRun(ctx, `{"skill": "s1"}`) assert.Error(t, 
err) assert.Contains(t, err.Error(), "AgentHub is not configured") }) t.Run("model specified without ModelHub returns error", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Context: ContextModeFork, Model: "gpt-4"}, Content: "c1"}, }}, toolName: "skill", agentHub: &mockAgentHub{}, } _, err := st.InvokableRun(ctx, `{"skill": "s1"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "ModelHub is not configured") }) t.Run("model not found in ModelHub returns error", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Context: ContextModeFork, Model: "gpt-4"}, Content: "c1"}, }}, toolName: "skill", agentHub: &mockAgentHub{}, modelHub: &mockModelHub{models: map[string]model.ToolCallingChatModel{}}, } _, err := st.InvokableRun(ctx, `{"skill": "s1"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to get model") }) t.Run("agent not found in AgentHub returns error", func(t *testing.T) { st := &skillTool{ b: &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Context: ContextModeFork, Agent: "nonexistent"}, Content: "c1"}, }}, toolName: "skill", agentHub: &mockAgentHub{agents: map[string]adk.Agent{}}, } _, err := st.InvokableRun(ctx, `{"skill": "s1"}`) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to get agent") }) t.Run("fork mode runs agent and returns result", func(t *testing.T) { agent := &mockAgent{ events: []*adk.AgentEvent{ { Output: &adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Message: schema.AssistantMessage("agent response", nil), }, }, }, }, } hub := &mockAgentHub{defaultAgent: agent} st := &skillTool{ b: &inMemoryBackend{m: []Skill{ { FrontMatter: FrontMatter{Name: "test-skill", Context: ContextModeFork}, Content: "skill content", BaseDirectory: "/skills/test", }, }}, toolName: "skill", agentHub: hub, } result, err := st.InvokableRun(ctx, `{"skill": "test-skill"}`) assert.NoError(t, err) 
assert.Contains(t, result, "test-skill") assert.Contains(t, result, "agent response") assert.Contains(t, result, "completed") // verify no model was passed in opts assert.NotNil(t, hub.lastOpts) assert.Nil(t, hub.lastOpts.Model) }) t.Run("fork mode with model passes model to AgentHub", func(t *testing.T) { m := &mockModel{name: "test-model"} agent := &mockAgent{ events: []*adk.AgentEvent{ { Output: &adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Message: schema.AssistantMessage("response", nil), }, }, }, }, } hub := &mockAgentHub{defaultAgent: agent} st := &skillTool{ b: &inMemoryBackend{m: []Skill{ { FrontMatter: FrontMatter{Name: "s1", Context: ContextModeFork, Model: "test-model"}, Content: "c1", BaseDirectory: "/skills/s1", }, }}, toolName: "skill", agentHub: hub, modelHub: &mockModelHub{models: map[string]model.ToolCallingChatModel{"test-model": m}}, } result, err := st.InvokableRun(ctx, `{"skill": "s1"}`) assert.NoError(t, err) assert.Contains(t, result, "s1") // verify model was passed assert.NotNil(t, hub.lastOpts) assert.Equal(t, m, hub.lastOpts.Model) }) t.Run("agent returns multiple events", func(t *testing.T) { agent := &mockAgent{ events: []*adk.AgentEvent{ { Output: &adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Message: schema.AssistantMessage("part1", nil), }, }, }, {Output: nil}, // nil output should be skipped { Output: &adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Message: schema.AssistantMessage("part2", nil), }, }, }, }, } hub := &mockAgentHub{defaultAgent: agent} st := &skillTool{ b: &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Context: ContextModeFork}, Content: "c1", BaseDirectory: "/d"}, }}, toolName: "skill", agentHub: hub, } result, err := st.InvokableRun(ctx, `{"skill": "s1"}`) assert.NoError(t, err) assert.Contains(t, result, "part1") assert.Contains(t, result, "part2") }) t.Run("agent returns empty content events", func(t *testing.T) { agent := &mockAgent{ events: []*adk.AgentEvent{ { Output: 
&adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Message: schema.AssistantMessage("", nil), }, }, }, }, } hub := &mockAgentHub{defaultAgent: agent} st := &skillTool{ b: &inMemoryBackend{m: []Skill{ {FrontMatter: FrontMatter{Name: "s1", Context: ContextModeFork}, Content: "c1", BaseDirectory: "/d"}, }}, toolName: "skill", agentHub: hub, } result, err := st.InvokableRun(ctx, `{"skill": "s1"}`) assert.NoError(t, err) // result should contain skill name but no extra content assert.Contains(t, result, "s1") }) } ================================================ FILE: adk/middlewares/summarization/consts.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package summarization const ( extraKeyContentType = "_eino_summarization_content_type" ) type summarizationContentType string const ( contentTypeSummary summarizationContentType = "summary" ) type ActionType string const ( ActionTypeBeforeSummarize ActionType = "before_summarize" ActionTypeAfterSummarize ActionType = "after_summarize" ) ================================================ FILE: adk/middlewares/summarization/customized_action.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package summarization import ( "github.com/cloudwego/eino/adk" ) type CustomizedAction struct { // Type is the action type. Type ActionType `json:"type"` // Before is set when Type is ActionTypeBeforeSummarize. // Emitted after trigger condition is met, before calling model to generate summary. Before *BeforeSummarizeAction `json:"before,omitempty"` // After is set when Type is ActionTypeAfterSummarize. // Emitted after summarization. After *AfterSummarizeAction `json:"after,omitempty"` } type BeforeSummarizeAction struct { // Messages is the original state messages before summarization. Messages []adk.Message `json:"messages,omitempty"` } type AfterSummarizeAction struct { // Messages is the final state messages after summarization. Messages []adk.Message `json:"messages,omitempty"` } ================================================ FILE: adk/middlewares/summarization/prompt.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package summarization import ( "regexp" "github.com/cloudwego/eino/adk/internal" ) var allUserMessagesTagRegex = regexp.MustCompile(`(?s).*`) func getSystemInstruction() string { return internal.SelectPrompt(internal.I18nPrompts{ English: systemInstruction, Chinese: systemInstructionZh, }) } func getUserSummaryInstruction() string { return internal.SelectPrompt(internal.I18nPrompts{ English: userSummaryInstruction, Chinese: userSummaryInstructionZh, }) } func getSummaryPreamble() string { return internal.SelectPrompt(internal.I18nPrompts{ English: summaryPreamble, Chinese: summaryPreambleZh, }) } func getContinueInstruction() string { return internal.SelectPrompt(internal.I18nPrompts{ English: continueInstruction, Chinese: continueInstructionZh, }) } func getTranscriptPathInstruction() string { return internal.SelectPrompt(internal.I18nPrompts{ English: transcriptPathInstruction, Chinese: transcriptPathInstructionZh, }) } func getTruncatedMarkerFormat() string { return internal.SelectPrompt(internal.I18nPrompts{ English: truncatedMarkerFormat, Chinese: truncatedMarkerFormatZh, }) } func getUserMessagesReplacedNote() string { return internal.SelectPrompt(internal.I18nPrompts{ English: userMessagesReplacedNote, Chinese: userMessagesReplacedNoteZh, }) } const systemInstruction = `You are a helpful AI assistant tasked with summarizing conversations.` const systemInstructionZh = `你是一个负责总结对话的 AI 助手。` const userSummaryInstruction = `Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions. This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing development work without losing context. Before providing your final summary, wrap your analysis in tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process: 1. 
Chronologically analyze each message and section of the conversation. For each section thoroughly identify: - The user's explicit requests and intents - Your approach to addressing the user's requests - Key decisions, technical concepts and code patterns - Specific details like: - file names - full code snippets - function signatures - file edits - Errors that you ran into and how you fixed them - Pay special attention to specific user feedback that you received, especially if the user told you to do something differently. 2. Double-check for technical accuracy and completeness, addressing each required element thoroughly. Your summary should include the following sections: 1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail 2. Key Technical Concepts: List all important technical concepts, technologies, and frameworks discussed. 3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important. 4. Errors and fixes: List all errors that you ran into, and how you fixed them. Pay special attention to specific user feedback that you received, especially if the user told you to do something differently. 5. Problem Solving: Document problems solved and any ongoing troubleshooting efforts. 6. All user messages: List ALL user messages that are not tool results, and wrap them in the ... block. These are critical for understanding the users' feedback and changing intent. 7. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on. 8. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable. 9. 
Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's most recent explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests or really old requests that were already completed without confirming with the user first. If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation. Here's an example of how your output should be structured: [Your thought process, ensuring all points are covered thoroughly and accurately] 1. Primary Request and Intent: [Detailed description] 2. Key Technical Concepts: - [Concept 1] - [Concept 2] - [...] 3. Files and Code Sections: - [File Name 1] - [Summary of why this file is important] - [Summary of the changes made to this file, if any] - [Important Code Snippet] - [File Name 2] - [Important Code Snippet] - [...] 4. Errors and fixes: - [Detailed description of error 1]: - [How you fixed the error] - [User feedback on the error if any] - [...] 5. Problem Solving: [Description of solved problems and ongoing troubleshooting] 6. All user messages: - [Detailed non tool use user message] - [...] 7. Pending Tasks: - [Task 1] - [Task 2] - [...] 8. Current Work: [Precise description of current work] 9. Optional Next Step: [Optional Next step to take] Please provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response. There may be additional summarization instructions provided in the included context. If so, remember to follow these instructions when creating the above summary. 
Examples of instructions include: ## Compact Instructions When summarizing the conversation focus on typescript code changes and also remember the mistakes you made and how you fixed them. # Summary instructions When you are using compact - please focus on test output and code changes. Include file reads verbatim. IMPORTANT: Do NOT use any tools. You MUST respond with ONLY the ... block as your text output. ` const userSummaryInstructionZh = `你的任务是对目前为止的对话创建一份详细的总结,需要密切关注用户的明确请求和你之前的操作。 这份总结应该全面捕捉技术细节、代码模式和架构决策,以确保继续开发工作时不丢失上下文。 在提供最终总结之前,请将你的分析过程包裹在 标签中,以组织思路并确保涵盖所有必要的要点。在分析过程中: 1. 按时间顺序分析对话中的每条消息和每个部分。对于每个部分,需要全面识别: - 用户的明确请求和意图 - 你处理用户请求的方法 - 关键决策、技术概念和代码模式 - 具体细节,例如: - 文件名 - 完整代码片段 - 函数签名 - 文件编辑 - 你遇到的错误以及如何修复它们 - 特别注意你收到的具体用户反馈,尤其是用户要求你以不同方式处理的情况 2. 仔细检查技术准确性和完整性,彻底处理每个必需的元素。 你的总结应包含以下部分: 1. 主要请求和意图:详细捕捉用户所有的明确请求和意图 2. 关键技术概念:列出讨论过的所有重要技术概念、技术和框架 3. 文件和代码部分:列举检查、修改或创建的具体文件和代码部分。特别注意最近的消息,在适用的地方包含完整的代码片段,并总结为什么这个文件的读取或编辑很重要 4. 错误和修复:列出你遇到的所有错误以及如何修复它们。特别注意你收到的具体用户反馈,尤其是用户要求你以不同方式处理的情况 5. 问题解决:记录已解决的问题和任何正在进行的故障排除工作 6. 所有用户消息:列出所有非工具结果的用户消息,并将它们包裹在 ... 块中。这些对于理解用户的反馈和变化的意图至关重要 7. 待处理任务:列出明确要求你处理的任何待处理任务 8. 当前工作:详细描述在此总结请求之前正在进行的工作,特别注意用户和助手的最近消息。在适用的地方包含文件名和代码片段 9. 可选的下一步:列出与你最近工作相关的下一步操作。重要提示:确保这一步与用户最近的明确请求以及你在此总结请求之前正在处理的任务直接相关。如果你的上一个任务已经完成,则只有在与用户请求明确相关时才列出下一步。不要在未与用户确认的情况下开始处理无关的请求或已经完成的旧请求。 如果有下一步,请包含最近对话中的直接引用,准确显示你正在处理的任务以及你停止的位置。这应该是逐字引用,以确保任务理解不会偏离。 以下是输出结构的示例: [你的思考过程,确保全面准确地涵盖所有要点] 1. 主要请求和意图: [详细描述] 2. 关键技术概念: - [概念 1] - [概念 2] - [...] 3. 文件和代码部分: - [文件名 1] - [为什么这个文件重要的总结] - [对这个文件所做更改的总结(如有)] - [重要代码片段] - [文件名 2] - [重要代码片段] - [...] 4. 错误和修复: - [错误 1 的详细描述]: - [如何修复该错误] - [用户对该错误的反馈(如有)] - [...] 5. 问题解决: [已解决问题和正在进行的故障排除的描述] 6. 所有用户消息: - [详细的非工具使用用户消息] - [...] 7. 待处理任务: - [任务 1] - [任务 2] - [...] 8. 当前工作: [当前工作的精确描述] 9. 可选的下一步: [可选的下一步操作] 请根据目前为止的对话提供你的总结,遵循此结构并确保回复的精确性和全面性。 上下文中可能包含额外的总结指令。如果有,请在创建上述总结时记得遵循这些指令。指令示例包括: ## 压缩指令 在总结对话时,重点关注 typescript 代码更改,并记住你犯的错误以及如何修复它们。 # 总结指令 当你使用压缩时,请重点关注测试输出和代码更改。逐字包含文件读取内容。 重要提示:不要使用任何工具。你必须只以 ... 
块作为文本输出进行回复。 ` const summaryPreamble = `This session is being continued from a previous conversation that ran out of context. The summary below covers the earlier portion of the conversation.` const summaryPreambleZh = `此会话延续自此前一段因上下文耗尽而终止的对话。以下总结概述了此前对话的内容。` const continueInstruction = `Please continue the conversation from where we left it off without asking the user any further questions. Continue with the last task that you were asked to work on.` const continueInstructionZh = `请从我们中断的地方继续对话,无需向用户提出任何进一步的问题。继续完成先前指令中未完成的任务。` const transcriptPathInstruction = `If you need specific details from before compaction (like exact code snippets, error messages, or content you generated), read the full transcript at: %s` const transcriptPathInstructionZh = `如果你需要压缩之前的具体细节(如精确的代码片段、错误消息或你生成的内容),完整的对话记录位于:%s` const truncatedMarkerFormat = "…%d characters truncated…" const truncatedMarkerFormatZh = "…已截断 %d 个字符…" const userMessagesReplacedNote = "Some earlier user messages have been cleared. Below are the most recent user messages:" const userMessagesReplacedNoteZh = "部分较早的用户消息已被清除,以下是保留的最近用户消息:" ================================================ FILE: adk/middlewares/summarization/summarization.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package summarization provides a middleware that automatically summarizes // conversation history when token count exceeds the configured threshold. 
package summarization import ( "context" "fmt" "regexp" "strings" "unicode/utf8" "github.com/bytedance/sonic" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" ) func init() { schema.RegisterName[*CustomizedAction]("_eino_adk_summarization_mw_customized_action") } type ( TokenCounterFunc func(ctx context.Context, input *TokenCounterInput) (int, error) GenModelInputFunc func(ctx context.Context, defaultSystemInstruction, userInstruction adk.Message, originalMsgs []adk.Message) ([]adk.Message, error) FinalizeFunc func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) CallbackFunc func(ctx context.Context, before, after adk.ChatModelAgentState) error ) // Config defines the configuration for the summarization middleware. type Config struct { // Model is the chat model used to generate summaries. Model model.BaseChatModel // ModelOptions specifies options passed to the model when generating summaries. // Optional. ModelOptions []model.Option // TokenCounter calculates the token count for a message. // Optional. Defaults to a simple estimator (~4 chars/token). TokenCounter TokenCounterFunc // Trigger specifies the conditions that activate summarization. // Optional. Defaults to triggering when total tokens exceed 190k. Trigger *TriggerCondition // EmitInternalEvents indicates whether internal events should be emitted during summarization, // allowing external observers to track the summarization process. // // Event Scoping: // - ActionTypeBeforeSummarize: emitted before calling model to generate summary // - ActionTypeAfterSummarize: emitted after summary generation completes // Optional. Defaults to false. EmitInternalEvents bool // UserInstruction serves as the user-level instruction to guide the model on how to summarize the context. // It is appended to the message history as a User message. 
// If provided, it overrides the default user summarization instruction. // Optional. UserInstruction string // TranscriptFilePath is the path to the file containing the full conversation history. // It is appended to the summary to remind the model where to read the original context. // Optional but strongly recommended. TranscriptFilePath string // GenModelInput allows full control over the summarization model input construction. // // Parameters: // - defaultSystemInstruction: System message defining the model's role // - userInstruction: User message with the task instruction // - originalMsgs: original complete message list // // Typical model input order: systemInstruction -> contextMessages -> userInstruction. // // Optional. GenModelInput GenModelInputFunc // Finalize is called after summary generation. The returned messages are used as the final output. // Optional. Finalize FinalizeFunc // Callback is called after Finalize, before exiting the middleware. // Read-only, do not modify state. // Optional. Callback CallbackFunc // PreserveUserMessages controls whether to preserve original user messages in the summary. // When enabled, replaces the section in the model-generated summary // with recent original user messages from the conversation. // When disabled, the model-generated content is kept unchanged. // Optional. Enabled by default. PreserveUserMessages *PreserveUserMessages } type TokenCounterInput struct { Messages []adk.Message Tools []*schema.ToolInfo } // TriggerCondition specifies when summarization should be activated. // Summarization triggers if ANY of the set conditions is met. type TriggerCondition struct { // ContextTokens triggers summarization when total token count exceeds this threshold. ContextTokens int // ContextMessages triggers summarization when total messages count exceeds this threshold. ContextMessages int } // PreserveUserMessages controls whether to preserve original user messages in the summary. 
type PreserveUserMessages struct { Enabled bool // MaxTokens limits the maximum token count for preserved user messages. // When set, only the most recent user messages within this limit are preserved. // Optional. Defaults to 1/3 of TriggerCondition.ContextTokens if not specified. MaxTokens int } // New creates a summarization middleware that automatically summarizes conversation history // when trigger conditions are met. func New(ctx context.Context, cfg *Config) (adk.ChatModelAgentMiddleware, error) { if err := cfg.check(); err != nil { return nil, err } return &middleware{ cfg: cfg, BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, }, nil } type middleware struct { *adk.BaseChatModelAgentMiddleware cfg *Config } func (m *middleware) BeforeModelRewriteState(ctx context.Context, state *adk.ChatModelAgentState, mtx *adk.ModelContext) (context.Context, *adk.ChatModelAgentState, error) { var tools []*schema.ToolInfo if mtx != nil { tools = mtx.Tools } triggered, err := m.shouldSummarize(ctx, &TokenCounterInput{ Messages: state.Messages, Tools: tools, }) if err != nil { return nil, nil, err } if !triggered { return ctx, state, nil } beforeState := *state if m.cfg.EmitInternalEvents { err = m.emitEvent(ctx, &CustomizedAction{ Type: ActionTypeBeforeSummarize, Before: &BeforeSummarizeAction{Messages: state.Messages}, }) if err != nil { return nil, nil, err } } var ( systemMsgs []adk.Message contextMsgs []adk.Message ) for _, msg := range state.Messages { if msg.Role == schema.System { systemMsgs = append(systemMsgs, msg) } else { contextMsgs = append(contextMsgs, msg) } } summary, err := m.summarize(ctx, state.Messages, contextMsgs) if err != nil { return nil, nil, err } summary, err = m.postProcessSummary(ctx, contextMsgs, summary) if err != nil { return nil, nil, err } if m.cfg.Finalize != nil { state.Messages, err = m.cfg.Finalize(ctx, state.Messages, summary) if err != nil { return nil, nil, err } } else { state.Messages = append(systemMsgs, 
summary) } if m.cfg.Callback != nil { err = m.cfg.Callback(ctx, beforeState, *state) if err != nil { return nil, nil, err } } if m.cfg.EmitInternalEvents { err = m.emitEvent(ctx, &CustomizedAction{ Type: ActionTypeAfterSummarize, After: &AfterSummarizeAction{Messages: state.Messages}, }) if err != nil { return nil, nil, err } } return ctx, state, nil } func (m *middleware) shouldSummarize(ctx context.Context, input *TokenCounterInput) (bool, error) { if m.cfg.Trigger != nil && m.cfg.Trigger.ContextMessages > 0 { if len(input.Messages) > m.cfg.Trigger.ContextMessages { return true, nil } } tokens, err := m.countTokens(ctx, input) if err != nil { return false, fmt.Errorf("failed to count tokens: %w", err) } return tokens > m.getTriggerContextTokens(), nil } func (m *middleware) getTriggerContextTokens() int { const defaultTriggerContextTokens = 190000 if m.cfg.Trigger != nil { return m.cfg.Trigger.ContextTokens } return defaultTriggerContextTokens } func (m *middleware) getUserMessageContextTokens() int { if m.cfg.PreserveUserMessages != nil && m.cfg.PreserveUserMessages.MaxTokens > 0 { return m.cfg.PreserveUserMessages.MaxTokens } return m.getTriggerContextTokens() / 3 } func (m *middleware) emitEvent(ctx context.Context, action *CustomizedAction) error { err := adk.SendEvent(ctx, &adk.AgentEvent{ Action: &adk.AgentAction{ CustomizedAction: action, }, }) if err != nil { return fmt.Errorf("failed to send internal event: %w", err) } return nil } func (m *middleware) countTokens(ctx context.Context, input *TokenCounterInput) (int, error) { if m.cfg.TokenCounter != nil { return m.cfg.TokenCounter(ctx, input) } return defaultTokenCounter(ctx, input) } func defaultTokenCounter(ctx context.Context, input *TokenCounterInput) (int, error) { var totalTokens int for _, msg := range input.Messages { text := extractTextContent(msg) totalTokens += estimateTokenCount(text) } for _, tl := range input.Tools { tl_ := *tl tl_.Extra = nil text, err := sonic.MarshalString(tl_) if err != 
nil { return 0, fmt.Errorf("failed to marshal tool info: %w", err) } totalTokens += estimateTokenCount(text) } return totalTokens, nil } func estimateTokenCount(text string) int { return (len(text) + 3) / 4 } func (m *middleware) summarize(ctx context.Context, originMsgs, contextMsgs []adk.Message) (adk.Message, error) { input, err := m.buildSummarizationModelInput(ctx, originMsgs, contextMsgs) if err != nil { return nil, err } resp, err := m.cfg.Model.Generate(ctx, input, m.cfg.ModelOptions...) if err != nil { return nil, fmt.Errorf("failed to generate summary: %w", err) } return newSummaryMessage(resp.Content), nil } func (m *middleware) buildSummarizationModelInput(ctx context.Context, originMsgs, contextMsgs []adk.Message) ([]adk.Message, error) { userInstruction := m.cfg.UserInstruction if userInstruction == "" { userInstruction = getUserSummaryInstruction() } userInstructionMsg := &schema.Message{ Role: schema.User, Content: userInstruction, } sysInstructionMsg := &schema.Message{ Role: schema.System, Content: getSystemInstruction(), } if m.cfg.GenModelInput != nil { input, err := m.cfg.GenModelInput(ctx, sysInstructionMsg, userInstructionMsg, originMsgs) if err != nil { return nil, fmt.Errorf("failed to generate model input: %w", err) } return input, nil } input := make([]adk.Message, 0, len(contextMsgs)+2) input = append(input, sysInstructionMsg) input = append(input, contextMsgs...) 
input = append(input, userInstructionMsg) return input, nil } func newSummaryMessage(content string) *schema.Message { summary := &schema.Message{ Role: schema.User, Content: content, } setContentType(summary, contentTypeSummary) return summary } func (m *middleware) postProcessSummary(ctx context.Context, messages []adk.Message, summary adk.Message) (adk.Message, error) { if m.cfg.PreserveUserMessages == nil || m.cfg.PreserveUserMessages.Enabled { maxUserMsgTokens := m.getUserMessageContextTokens() content, err := m.replaceUserMessagesInSummary(ctx, messages, summary.Content, maxUserMsgTokens) if err != nil { return nil, fmt.Errorf("failed to replace user messages in summary: %w", err) } summary.Content = content } if path := m.cfg.TranscriptFilePath; path != "" { summary.Content = appendSection(summary.Content, fmt.Sprintf(getTranscriptPathInstruction(), path)) } summary.Content = appendSection(getSummaryPreamble(), summary.Content) summary.UserInputMultiContent = []schema.MessageInputPart{ { Type: schema.ChatMessagePartTypeText, Text: summary.Content, }, { Type: schema.ChatMessagePartTypeText, Text: getContinueInstruction(), }, } summary.Content = "" return summary, nil } func (m *middleware) replaceUserMessagesInSummary(ctx context.Context, messages []adk.Message, summary string, contextTokens int) (string, error) { var userMsgs []adk.Message for _, msg := range messages { if typ, ok := getContentType(msg); ok && typ == contentTypeSummary { continue } if msg.Role == schema.User { userMsgs = append(userMsgs, msg) } } if len(userMsgs) == 0 { return summary, nil } var selected []adk.Message if len(userMsgs) == 1 { selected = userMsgs } else { var totalTokens int for i := len(userMsgs) - 1; i >= 0; i-- { msg := userMsgs[i] tokens, err := m.countTokens(ctx, &TokenCounterInput{ Messages: []adk.Message{msg}, }) if err != nil { return "", fmt.Errorf("failed to count tokens: %w", err) } remaining := contextTokens - totalTokens if tokens <= remaining { totalTokens += 
tokens selected = append(selected, msg) continue } trimmedMsg := defaultTrimUserMessage(msg, remaining) if trimmedMsg != nil { selected = append(selected, trimmedMsg) } break } for i, j := 0, len(selected)-1; i < j; i, j = i+1, j-1 { selected[i], selected[j] = selected[j], selected[i] } } var msgLines []string for _, msg := range selected { text := extractTextContent(msg) if text != "" { msgLines = append(msgLines, " - "+text) } } userMsgsText := strings.Join(msgLines, "\n") if userMsgsText == "" { return summary, nil } lastMatch := findLastMatch(allUserMessagesTagRegex, summary) if lastMatch == nil { return summary, nil } var replacement string if len(selected) < len(userMsgs) { replacement = "\n" + getUserMessagesReplacedNote() + "\n" + userMsgsText + "\n" } else { replacement = "\n" + userMsgsText + "\n" } content := summary[:lastMatch[0]] + replacement + summary[lastMatch[1]:] return content, nil } func findLastMatch(re *regexp.Regexp, s string) []int { matches := re.FindAllStringIndex(s, -1) if len(matches) == 0 { return nil } return matches[len(matches)-1] } func appendSection(base, section string) string { if base == "" { return section } if section == "" { return base } return base + "\n\n" + section } func defaultTrimUserMessage(msg adk.Message, remainingTokens int) adk.Message { if remainingTokens <= 0 { return nil } textContent := extractTextContent(msg) if len(textContent) == 0 { return nil } trimmed := truncateTextByChars(textContent) if trimmed == "" { return nil } return &schema.Message{ Role: schema.User, Content: trimmed, } } func truncateTextByChars(text string) string { const maxRunes = 2000 if text == "" { return "" } if utf8.RuneCountInString(text) <= maxRunes { return text } halfRunes := maxRunes / 2 runes := []rune(text) totalRunes := len(runes) prefix := string(runes[:halfRunes]) suffix := string(runes[totalRunes-halfRunes:]) removedChars := totalRunes - maxRunes marker := fmt.Sprintf(getTruncatedMarkerFormat(), removedChars) return prefix + 
marker + suffix } func extractTextContent(msg adk.Message) string { if msg == nil { return "" } if msg.Content != "" { return msg.Content } var sb strings.Builder for _, part := range msg.UserInputMultiContent { if part.Type == schema.ChatMessagePartTypeText && part.Text != "" { if sb.Len() > 0 { sb.WriteString("\n") } sb.WriteString(part.Text) } } return sb.String() } func (c *Config) check() error { if c == nil { return fmt.Errorf("config is required") } if c.Model == nil { return fmt.Errorf("model is required") } if c.Trigger != nil { if err := c.Trigger.check(); err != nil { return err } } return nil } func (c *TriggerCondition) check() error { if c.ContextTokens < 0 { return fmt.Errorf("trigger.ContextTokens must be non-negative") } if c.ContextMessages < 0 { return fmt.Errorf("trigger.ContextMessages must be non-negative") } if c.ContextTokens == 0 && c.ContextMessages == 0 { return fmt.Errorf("at least one of trigger.ContextTokens or trigger.ContextMessages must be non-negative") } return nil } func setContentType(msg adk.Message, ct summarizationContentType) { setExtra(msg, extraKeyContentType, string(ct)) } func getContentType(msg adk.Message) (summarizationContentType, bool) { ct, ok := getExtra[string](msg, extraKeyContentType) if !ok { return "", false } return summarizationContentType(ct), true } func setExtra(msg adk.Message, key string, value any) { if msg.Extra == nil { msg.Extra = make(map[string]any) } msg.Extra[key] = value } func getExtra[T any](msg adk.Message, key string) (T, bool) { var zero T if msg == nil || msg.Extra == nil { return zero, false } v, ok := msg.Extra[key].(T) if !ok { return zero, false } return v, true } ================================================ FILE: adk/middlewares/summarization/summarization_test.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the 
License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package summarization import ( "context" "errors" "strings" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/adk" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) func TestNew(t *testing.T) { ctx := context.Background() t.Run("valid config", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cfg := &Config{ Model: cm, } mw, err := New(ctx, cfg) assert.NoError(t, err) assert.NotNil(t, mw) }) t.Run("nil config returns error", func(t *testing.T) { mw, err := New(ctx, nil) assert.Error(t, err) assert.Nil(t, mw) }) t.Run("nil model returns error", func(t *testing.T) { mw, err := New(ctx, &Config{}) assert.Error(t, err) assert.Nil(t, mw) }) } func TestMiddlewareBeforeModelRewriteState(t *testing.T) { ctx := context.Background() mtx := &adk.ModelContext{} t.Run("no summarization when under threshold", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) mw := &middleware{ cfg: &Config{ Model: cm, Trigger: &TriggerCondition{ContextTokens: 1000}, }, BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, } state := &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage("hello"), schema.AssistantMessage("hi", nil), }, } _, newState, err := mw.BeforeModelRewriteState(ctx, state, mtx) assert.NoError(t, err) assert.Len(t, newState.Messages, 2) assert.Equal(t, "hello", newState.Messages[0].Content) }) t.Run("summarization triggered when over 
threshold", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(&schema.Message{ Role: schema.Assistant, Content: "Summary content", }, nil).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, Trigger: &TriggerCondition{ContextTokens: 10}, }, BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, } state := &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage(strings.Repeat("a", 100)), schema.AssistantMessage(strings.Repeat("b", 100), nil), }, } _, newState, err := mw.BeforeModelRewriteState(ctx, state, mtx) assert.NoError(t, err) assert.Len(t, newState.Messages, 1) assert.Equal(t, schema.User, newState.Messages[0].Role) }) t.Run("preserves system messages after summarization", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { for i, msg := range msgs { if i == 0 { assert.Equal(t, schema.System, msg.Role) } else { assert.NotEqual(t, schema.System, msg.Role) } } return &schema.Message{ Role: schema.Assistant, Content: "Summary content", }, nil }).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, Trigger: &TriggerCondition{ContextTokens: 10}, }, BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, } state := &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.SystemMessage("You are a helpful assistant"), schema.UserMessage(strings.Repeat("a", 100)), schema.AssistantMessage(strings.Repeat("b", 100), nil), }, } _, newState, err := mw.BeforeModelRewriteState(ctx, state, mtx) assert.NoError(t, err) assert.Len(t, newState.Messages, 2) assert.Equal(t, schema.System, newState.Messages[0].Role) assert.Equal(t, "You are a helpful assistant", newState.Messages[0].Content) assert.Equal(t, 
schema.User, newState.Messages[1].Role) }) t.Run("preserves multiple system messages", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(&schema.Message{ Role: schema.Assistant, Content: "Summary", }, nil).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, Trigger: &TriggerCondition{ContextTokens: 10}, }, BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, } state := &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.SystemMessage("System 1"), schema.SystemMessage("System 2"), schema.UserMessage(strings.Repeat("a", 100)), }, } _, newState, err := mw.BeforeModelRewriteState(ctx, state, mtx) assert.NoError(t, err) assert.Len(t, newState.Messages, 3) assert.Equal(t, schema.System, newState.Messages[0].Role) assert.Equal(t, "System 1", newState.Messages[0].Content) assert.Equal(t, schema.System, newState.Messages[1].Role) assert.Equal(t, "System 2", newState.Messages[1].Content) assert.Equal(t, schema.User, newState.Messages[2].Role) }) t.Run("custom finalize function", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(&schema.Message{ Role: schema.Assistant, Content: "Summary", }, nil).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, Trigger: &TriggerCondition{ContextTokens: 10}, Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) { return []adk.Message{ schema.SystemMessage("system prompt"), summary, }, nil }, }, BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, } state := &adk.ChatModelAgentState{ Messages: []adk.Message{ schema.UserMessage(strings.Repeat("a", 100)), }, } _, newState, err := mw.BeforeModelRewriteState(ctx, state, mtx) assert.NoError(t, err) assert.Len(t, newState.Messages, 2) assert.Equal(t, schema.System, newState.Messages[0].Role) assert.Equal(t, "system prompt", newState.Messages[0].Content) }) } func TestMiddlewareShouldSummarize(t *testing.T) { ctx := context.Background() t.Run("returns true when over messages threshold", func(t *testing.T) { mw := &middleware{ cfg: &Config{ Trigger: &TriggerCondition{ContextMessages: 1}, }, } input := &TokenCounterInput{ Messages: []adk.Message{ schema.UserMessage("msg1"), schema.UserMessage("msg2"), }, } triggered, err := mw.shouldSummarize(ctx, input) assert.NoError(t, err) assert.True(t, triggered) }) t.Run("returns false when under messages threshold", func(t *testing.T) { mw := &middleware{ cfg: &Config{ Trigger: &TriggerCondition{ ContextMessages: 3, ContextTokens: 1000, }, }, } input := &TokenCounterInput{ Messages: []adk.Message{ schema.UserMessage("msg1"), schema.UserMessage("msg2"), }, } triggered, err := mw.shouldSummarize(ctx, input) assert.NoError(t, err) assert.False(t, triggered) }) t.Run("returns true when over threshold", func(t *testing.T) { mw := &middleware{ cfg: &Config{ Trigger: &TriggerCondition{ContextTokens: 10}, }, } input := &TokenCounterInput{ Messages: []adk.Message{ schema.UserMessage(strings.Repeat("a", 100)), }, } triggered, err := mw.shouldSummarize(ctx, input) assert.NoError(t, err) assert.True(t, 
triggered) }) t.Run("returns false when under threshold", func(t *testing.T) { mw := &middleware{ cfg: &Config{ Trigger: &TriggerCondition{ContextTokens: 1000}, }, } input := &TokenCounterInput{ Messages: []adk.Message{ schema.UserMessage("short message"), }, } triggered, err := mw.shouldSummarize(ctx, input) assert.NoError(t, err) assert.False(t, triggered) }) t.Run("uses default threshold when trigger is nil", func(t *testing.T) { mw := &middleware{ cfg: &Config{}, } input := &TokenCounterInput{ Messages: []adk.Message{ schema.UserMessage("short message"), }, } triggered, err := mw.shouldSummarize(ctx, input) assert.NoError(t, err) assert.False(t, triggered) }) } func TestMiddlewareCountTokens(t *testing.T) { ctx := context.Background() t.Run("uses custom token counter", func(t *testing.T) { mw := &middleware{ cfg: &Config{ TokenCounter: func(ctx context.Context, input *TokenCounterInput) (int, error) { return 42, nil }, }, } input := &TokenCounterInput{ Messages: []adk.Message{schema.UserMessage("test")}, } tokens, err := mw.countTokens(ctx, input) assert.NoError(t, err) assert.Equal(t, 42, tokens) }) t.Run("uses default token counter when nil", func(t *testing.T) { mw := &middleware{ cfg: &Config{}, } input := &TokenCounterInput{ Messages: []adk.Message{schema.UserMessage("test")}, } tokens, err := mw.countTokens(ctx, input) assert.NoError(t, err) assert.Equal(t, 1, tokens) }) t.Run("custom token counter error", func(t *testing.T) { mw := &middleware{ cfg: &Config{ TokenCounter: func(ctx context.Context, input *TokenCounterInput) (int, error) { return 0, errors.New("token count error") }, }, } input := &TokenCounterInput{ Messages: []adk.Message{schema.UserMessage("test")}, } _, err := mw.countTokens(ctx, input) assert.Error(t, err) }) } func TestExtractTextContent(t *testing.T) { t.Run("extracts from Content field", func(t *testing.T) { msg := &schema.Message{ Role: schema.User, Content: "hello world", } assert.Equal(t, "hello world", extractTextContent(msg)) 
}) t.Run("extracts from UserInputMultiContent", func(t *testing.T) { msg := &schema.Message{ Role: schema.User, UserInputMultiContent: []schema.MessageInputPart{ {Type: schema.ChatMessagePartTypeText, Text: "part1"}, {Type: schema.ChatMessagePartTypeText, Text: "part2"}, }, } assert.Equal(t, "part1\npart2", extractTextContent(msg)) }) t.Run("prefers Content over UserInputMultiContent", func(t *testing.T) { msg := &schema.Message{ Role: schema.User, Content: "content field", UserInputMultiContent: []schema.MessageInputPart{ {Type: schema.ChatMessagePartTypeText, Text: "multi content"}, }, } assert.Equal(t, "content field", extractTextContent(msg)) }) } func TestTruncateTextByChars(t *testing.T) { t.Run("returns empty for empty string", func(t *testing.T) { result := truncateTextByChars("") assert.Equal(t, "", result) }) t.Run("returns original if under limit", func(t *testing.T) { result := truncateTextByChars("short") assert.Equal(t, "short", result) }) t.Run("truncates long text", func(t *testing.T) { longText := strings.Repeat("a", 3000) result := truncateTextByChars(longText) assert.Less(t, len(result), len(longText)) assert.Contains(t, result, "truncated") }) t.Run("preserves prefix and suffix", func(t *testing.T) { longText := strings.Repeat("a", 1000) + strings.Repeat("b", 1000) + strings.Repeat("c", 1000) result := truncateTextByChars(longText) assert.True(t, strings.HasPrefix(result, strings.Repeat("a", 1000))) assert.True(t, strings.HasSuffix(result, strings.Repeat("c", 1000))) }) } func TestAppendSection(t *testing.T) { tests := []struct { name string base string section string expected string }{ { name: "both empty", base: "", section: "", expected: "", }, { name: "base empty", base: "", section: "section", expected: "section", }, { name: "section empty", base: "base", section: "", expected: "base", }, { name: "both non-empty", base: "base", section: "section", expected: "base\n\nsection", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) 
{ result := appendSection(tt.base, tt.section) assert.Equal(t, tt.expected, result) }) } } func TestAllUserMessagesTagRegex(t *testing.T) { t.Run("matches tag", func(t *testing.T) { text := ` - msg1 - msg2 ` assert.True(t, allUserMessagesTagRegex.MatchString(text)) }) t.Run("replaces tag content", func(t *testing.T) { text := `before - old msg after` replacement := "\n - new msg\n" result := allUserMessagesTagRegex.ReplaceAllString(text, replacement) assert.Contains(t, result, "new msg") assert.NotContains(t, result, "old msg") assert.Contains(t, result, "before") assert.Contains(t, result, "after") }) } func TestConfigCheck(t *testing.T) { t.Run("nil config", func(t *testing.T) { var c *Config err := c.check() assert.Error(t, err) assert.Contains(t, err.Error(), "config is required") }) t.Run("nil model", func(t *testing.T) { c := &Config{} err := c.check() assert.Error(t, err) assert.Contains(t, err.Error(), "model is required") }) t.Run("valid config", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) c := &Config{ Model: cm, } err := c.check() assert.NoError(t, err) }) t.Run("invalid trigger max tokens", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) c := &Config{ Model: cm, Trigger: &TriggerCondition{ContextTokens: -1}, } err := c.check() assert.Error(t, err) assert.Contains(t, err.Error(), "must be non-negative") }) t.Run("invalid trigger max messages", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) c := &Config{ Model: cm, Trigger: &TriggerCondition{ContextMessages: -1}, } err := c.check() assert.Error(t, err) assert.Contains(t, err.Error(), "must be non-negative") }) t.Run("both trigger conditions are zero", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) c := &Config{ Model: cm, Trigger: &TriggerCondition{ContextTokens: 0, ContextMessages: 0}, } err := c.check() 
assert.Error(t, err) assert.Contains(t, err.Error(), "must be non-negative") }) } func TestSetGetContentType(t *testing.T) { msg := &schema.Message{ Role: schema.User, Content: "test", } setContentType(msg, contentTypeSummary) ct, ok := getContentType(msg) assert.True(t, ok) assert.Equal(t, contentTypeSummary, ct) } func TestSetGetExtra(t *testing.T) { t.Run("set and get", func(t *testing.T) { msg := &schema.Message{ Role: schema.User, Content: "test", } setExtra(msg, "key", "value") v, ok := getExtra[string](msg, "key") assert.True(t, ok) assert.Equal(t, "value", v) }) t.Run("get from nil message", func(t *testing.T) { v, ok := getExtra[string](nil, "key") assert.False(t, ok) assert.Equal(t, "", v) }) t.Run("get non-existent key", func(t *testing.T) { msg := &schema.Message{ Role: schema.User, Content: "test", } v, ok := getExtra[string](msg, "non-existent") assert.False(t, ok) assert.Equal(t, "", v) }) } func TestMiddlewareSummarize(t *testing.T) { ctx := context.Background() t.Run("message structure", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { assert.GreaterOrEqual(t, len(msgs), 3) assert.Equal(t, schema.System, msgs[0].Role) assert.Equal(t, schema.User, msgs[len(msgs)-1].Role) return &schema.Message{ Role: schema.Assistant, Content: "summary", }, nil }).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, }, } testMsg := []adk.Message{schema.UserMessage("test")} _, err := mw.summarize(ctx, testMsg, testMsg) assert.NoError(t, err) }) t.Run("uses context messages", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { // Verify the context messages are included found := false for _, msg := range msgs { if msg.Content == "context message" { found = true break } } assert.True(t, found, "should contain context message") return &schema.Message{ Role: schema.Assistant, Content: "summary", }, nil }).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, }, } contextMsgs := []adk.Message{ schema.UserMessage("context message"), } _, err := mw.summarize(ctx, contextMsgs, contextMsgs) assert.NoError(t, err) }) t.Run("uses GenModelInput", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) expectedInput := []adk.Message{ schema.UserMessage("custom input"), } cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { assert.Len(t, msgs, 1) assert.Equal(t, "custom input", msgs[0].Content) return &schema.Message{ Role: schema.Assistant, Content: "summary", }, nil }).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, GenModelInput: func(ctx context.Context, defaultSystemInstruction, userInstruction adk.Message, originalMsgs []adk.Message) ([]adk.Message, error) { return expectedInput, nil }, }, } testMsg := []adk.Message{schema.UserMessage("test")} _, err := mw.summarize(ctx, testMsg, testMsg) assert.NoError(t, err) }) t.Run("GenModelInput error", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) mw := &middleware{ cfg: &Config{ Model: cm, GenModelInput: func(ctx context.Context, defaultSystemInstruction, userInstruction adk.Message, originalMsgs []adk.Message) ([]adk.Message, error) { return nil, errors.New("gen input error") }, }, } testMsg := []adk.Message{schema.UserMessage("test")} _, err := mw.summarize(ctx, testMsg, testMsg) assert.Error(t, err) assert.Contains(t, err.Error(), 
"gen input error") }) t.Run("uses custom instruction", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { lastMsg := msgs[len(msgs)-1] assert.Equal(t, schema.User, lastMsg.Role) assert.Contains(t, lastMsg.Content, "custom instruction") return &schema.Message{ Role: schema.Assistant, Content: "summary", }, nil }).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, UserInstruction: "custom instruction", }, } testMsg := []adk.Message{schema.UserMessage("test")} _, err := mw.summarize(ctx, testMsg, testMsg) assert.NoError(t, err) }) t.Run("model generate error", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockBaseChatModel(ctrl) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, errors.New("generate error")).Times(1) mw := &middleware{ cfg: &Config{ Model: cm, }, } testMsg := []adk.Message{schema.UserMessage("test")} _, err := mw.summarize(ctx, testMsg, testMsg) assert.Error(t, err) }) } func TestReplaceUserMessagesInSummary(t *testing.T) { ctx := context.Background() t.Run("replaces user messages section", func(t *testing.T) { mw := &middleware{ cfg: &Config{}, } msgs := []adk.Message{ schema.UserMessage("msg1"), schema.AssistantMessage("response1", nil), schema.UserMessage("msg2"), } summary := `1. Primary Request: test 6. All user messages: - [old message] 7. Pending Tasks: - task1` result, err := mw.replaceUserMessagesInSummary(ctx, msgs, summary, 1000) assert.NoError(t, err) assert.Contains(t, result, "msg1") assert.Contains(t, result, "msg2") assert.NotContains(t, result, "old message") assert.Contains(t, result, "7. 
Pending Tasks:") }) t.Run("returns original if no matching sections", func(t *testing.T) { mw := &middleware{ cfg: &Config{}, } msgs := []adk.Message{ schema.UserMessage("test"), } summary := "summary without sections" result, err := mw.replaceUserMessagesInSummary(ctx, msgs, summary, 1000) assert.NoError(t, err) assert.Equal(t, summary, result) }) t.Run("skips summary messages", func(t *testing.T) { mw := &middleware{ cfg: &Config{}, } summaryMsg := &schema.Message{ Role: schema.User, Content: "summary", } setContentType(summaryMsg, contentTypeSummary) msgs := []adk.Message{ summaryMsg, schema.UserMessage("regular message"), } summary := `6. All user messages: - [old] 7. Pending Tasks: - task` result, err := mw.replaceUserMessagesInSummary(ctx, msgs, summary, 1000) assert.NoError(t, err) assert.Contains(t, result, "regular message") assert.NotContains(t, result, " - summary") }) t.Run("token counter error", func(t *testing.T) { mw := &middleware{ cfg: &Config{ TokenCounter: func(ctx context.Context, input *TokenCounterInput) (int, error) { return 0, errors.New("count error") }, }, } msgs := []adk.Message{ schema.UserMessage("test1"), schema.UserMessage("test2"), } _, err := mw.replaceUserMessagesInSummary(ctx, msgs, "summary", 1000) assert.Error(t, err) }) t.Run("returns original if empty user messages", func(t *testing.T) { mw := &middleware{ cfg: &Config{}, } msgs := []adk.Message{ schema.AssistantMessage("response", nil), } summary := `6. All user messages: - [old] 7. Pending Tasks: - task` result, err := mw.replaceUserMessagesInSummary(ctx, msgs, summary, 1000) assert.NoError(t, err) assert.Equal(t, summary, result) }) } func TestAllUserMessagesTagRegexMatch(t *testing.T) { t.Run("matches xml tag", func(t *testing.T) { text := "\n - msg\n" assert.True(t, allUserMessagesTagRegex.MatchString(text)) }) t.Run("does not match without tag", func(t *testing.T) { text := "6. 
All user messages:\n - msg" assert.False(t, allUserMessagesTagRegex.MatchString(text)) }) } func TestDefaultTrimUserMessage(t *testing.T) { t.Run("returns nil for zero remaining tokens", func(t *testing.T) { msg := schema.UserMessage("test") result := defaultTrimUserMessage(msg, 0) assert.Nil(t, result) }) t.Run("returns nil for empty content", func(t *testing.T) { msg := schema.UserMessage("") result := defaultTrimUserMessage(msg, 100) assert.Nil(t, result) }) t.Run("trims long message", func(t *testing.T) { longText := strings.Repeat("a", 3000) msg := schema.UserMessage(longText) result := defaultTrimUserMessage(msg, 100) assert.NotNil(t, result) assert.Less(t, len(result.Content), len(longText)) }) } func TestDefaultTokenCounter(t *testing.T) { ctx := context.Background() t.Run("counts tool tokens", func(t *testing.T) { input := &TokenCounterInput{ Messages: []adk.Message{}, Tools: []*schema.ToolInfo{ {Name: "test_tool", Desc: "description"}, }, } count, err := defaultTokenCounter(ctx, input) assert.NoError(t, err) assert.Greater(t, count, 0) }) } func TestPostProcessSummary(t *testing.T) { ctx := context.Background() t.Run("with transcript path", func(t *testing.T) { mw := &middleware{ cfg: &Config{ TranscriptFilePath: "/path/to/transcript.txt", }, } summary := &schema.Message{ Role: schema.User, Content: "summary content", } result, err := mw.postProcessSummary(ctx, []adk.Message{}, summary) assert.NoError(t, err) assert.Len(t, result.UserInputMultiContent, 2) assert.Contains(t, result.UserInputMultiContent[0].Text, "/path/to/transcript.txt") }) } ================================================ FILE: adk/prebuilt/deep/checkpoint_compat_resume_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package deep import ( "context" "os" "path/filepath" "runtime" "strings" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) type compatCheckpointStore struct { data map[string][]byte } func newCompatCheckpointStore() *compatCheckpointStore { return &compatCheckpointStore{data: make(map[string][]byte)} } func (s *compatCheckpointStore) Set(_ context.Context, key string, value []byte) error { s.data[key] = append([]byte(nil), value...) 
return nil } func (s *compatCheckpointStore) Get(_ context.Context, key string) ([]byte, bool, error) { v, ok := s.data[key] if !ok { return nil, false, nil } return append([]byte(nil), v...), true, nil } type interruptingSubAgentTool struct { name string } func (t *interruptingSubAgentTool) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: t.name, Desc: "interrupts on first call and resumes from stored state", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "action": {Type: schema.String}, }), }, nil } func (t *interruptingSubAgentTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { _, _, _ = tool.GetInterruptState[string](ctx) return "resumed", nil } func readTestdataBytes(t *testing.T, filename string) []byte { t.Helper() _, file, _, ok := runtime.Caller(0) assert.True(t, ok) p := filepath.Join(filepath.Dir(file), "testdata", filename) b, err := os.ReadFile(p) assert.NoError(t, err) assert.NotEmpty(t, b) return b } func runDeepAgentCheckpointCompat(t *testing.T, checkpointID string, filename string) { t.Helper() ctx := context.Background() data := readTestdataBytes(t, filename) store := newCompatCheckpointStore() assert.NoError(t, store.Set(ctx, checkpointID, data)) ctrl := gomock.NewController(t) defer ctrl.Finish() interruptToolName := "interrupt_in_subagent_tool" subTool := &interruptingSubAgentTool{name: interruptToolName} deepModel := mockModel.NewMockBaseChatModel(ctrl) subModel := mockModel.NewMockBaseChatModel(ctrl) deepModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { return schema.AssistantMessage("deep done", nil), nil }).AnyTimes() subModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { return schema.AssistantMessage("sub done", nil), nil }).AnyTimes() subAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "sub_chatmodel_agent", Description: "sub agent", Model: subModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{subTool}, }, }, MaxIterations: 4, }) assert.NoError(t, err) deepAgent, err := New(ctx, &Config{ Name: "deep", Description: "deep agent", ChatModel: deepModel, SubAgents: []adk.Agent{subAgent}, MaxIteration: 4, WithoutWriteTodos: true, WithoutGeneralSubAgent: true, }) assert.NoError(t, err) runner := adk.NewRunner(ctx, adk.RunnerConfig{ Agent: deepAgent, CheckPointStore: store, }) it, err := runner.Resume(ctx, checkpointID) assert.NoError(t, err) var sawDeepDone bool var sawAnyOutput bool for { ev, ok := it.Next() if !ok { break } assert.NoError(t, ev.Err) if ev.Output != nil && ev.Output.MessageOutput != nil && ev.Output.MessageOutput.Message != nil { sawAnyOutput = true msg := ev.Output.MessageOutput.Message if msg.Role == schema.Assistant && strings.Contains(msg.Content, "deep done") { sawDeepDone = true } } } assert.True(t, sawAnyOutput) assert.True(t, sawDeepDone) } func TestDeepAgentCheckpointCompat_V0_8_Resume(t *testing.T) { tests := []struct { name string checkpointID string filename string }{ { name: "v0.7.37", checkpointID: "checkpoint_compat_v0_7_37", filename: "checkpoint_data_v0.7.37.bin", }, { name: "v0.8.2", checkpointID: "checkpoint_compat_v0_8_2", filename: "checkpoint_data_v0.8.2.bin", }, { name: "v0.8.3", checkpointID: "checkpoint_compat_v0_8_3", filename: "checkpoint_data_v0.8.3.bin", }, { name: "v0.8.4", checkpointID: "checkpoint_compat_v0_8_4", filename: "checkpoint_data_v0.8.4.bin", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { runDeepAgentCheckpointCompat(t, tc.checkpointID, tc.filename) }) } } 
================================================ FILE: adk/prebuilt/deep/deep.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package deep provides a prebuilt agent with deep task orchestration. package deep import ( "context" "fmt" "github.com/bytedance/sonic" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/adk/internal" filesystem2 "github.com/cloudwego/eino/adk/middlewares/filesystem" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool/utils" "github.com/cloudwego/eino/schema" ) func init() { schema.RegisterName[TODO]("_eino_adk_prebuilt_deep_todo") schema.RegisterName[[]TODO]("_eino_adk_prebuilt_deep_todo_slice") } // Config defines the configuration for creating a DeepAgent. type Config struct { // Name is the identifier for the Deep agent. Name string // Description provides a brief explanation of the agent's purpose. Description string // ChatModel is the model used by DeepAgent for reasoning and task execution. // If the agent uses any tools, this model must support the model.WithTools call option, // as that's how the agent configures the model with tool information. ChatModel model.BaseChatModel // Instruction contains the system prompt that guides the agent's behavior. 
// When empty, a built-in default system prompt will be used, which includes general assistant // behavior guidelines, security policies, coding style guidelines, and tool usage policies. Instruction string // SubAgents are specialized agents that can be invoked by the agent. SubAgents []adk.Agent // ToolsConfig provides the tools and tool-calling configurations available for the agent to invoke. ToolsConfig adk.ToolsConfig // MaxIteration limits the maximum number of reasoning iterations the agent can perform. MaxIteration int // Backend provides filesystem operations used by tools and offloading. // If set, filesystem tools (read_file, write_file, edit_file, glob, grep) will be registered. // Optional. Backend filesystem.Backend // Shell provides shell command execution capability. // If set, an execute tool will be registered to support shell command execution. // Optional. Mutually exclusive with StreamingShell. Shell filesystem.Shell // StreamingShell provides streaming shell command execution capability. // If set, a streaming execute tool will be registered to support streaming shell command execution. // Optional. Mutually exclusive with Shell. StreamingShell filesystem.StreamingShell // WithoutWriteTodos disables the built-in write_todos tool when set to true. WithoutWriteTodos bool // WithoutGeneralSubAgent disables the general-purpose subagent when set to true. WithoutGeneralSubAgent bool // TaskToolDescriptionGenerator allows customizing the description for the task tool. // If provided, this function generates the tool description based on available subagents. TaskToolDescriptionGenerator func(ctx context.Context, availableAgents []adk.Agent) (string, error) Middlewares []adk.AgentMiddleware // Handlers configures interface-based handlers for extending agent behavior. 
// Unlike Middlewares (struct-based), Handlers allow users to: // - Add custom methods to their handler implementations // - Return modified context from handler methods // - Centralize configuration in struct fields instead of closures // // Handlers are processed after Middlewares, in registration order. // See adk.ChatModelAgentMiddleware documentation for when to use Handlers vs Middlewares. Handlers []adk.ChatModelAgentMiddleware ModelRetryConfig *adk.ModelRetryConfig // OutputKey stores the agent's response in the session. // Optional. When set, stores output via AddSessionValue(ctx, outputKey, msg.Content). OutputKey string } // New creates a new Deep agent instance with the provided configuration. // This function initializes built-in tools, creates a task tool for subagent orchestration, // and returns a fully configured ChatModelAgent ready for execution. func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) { handlers, err := buildBuiltinAgentMiddlewares(ctx, cfg) if err != nil { return nil, err } instruction := cfg.Instruction if len(instruction) == 0 { instruction = internal.SelectPrompt(internal.I18nPrompts{ English: baseAgentInstruction, Chinese: baseAgentInstructionChinese, }) } if !cfg.WithoutGeneralSubAgent || len(cfg.SubAgents) > 0 { tt, err := newTaskToolMiddleware( ctx, cfg.TaskToolDescriptionGenerator, cfg.SubAgents, cfg.WithoutGeneralSubAgent, cfg.ChatModel, instruction, cfg.ToolsConfig, cfg.MaxIteration, cfg.Middlewares, append(handlers, cfg.Handlers...), ) if err != nil { return nil, fmt.Errorf("failed to new task tool: %w", err) } handlers = append(handlers, tt) } return adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: cfg.Name, Description: cfg.Description, Instruction: instruction, Model: cfg.ChatModel, ToolsConfig: cfg.ToolsConfig, MaxIterations: cfg.MaxIteration, Middlewares: cfg.Middlewares, Handlers: append(handlers, cfg.Handlers...), GenModelInput: genModelInput, ModelRetryConfig: cfg.ModelRetryConfig, 
OutputKey: cfg.OutputKey, }) } func genModelInput(ctx context.Context, instruction string, input *adk.AgentInput) ([]*schema.Message, error) { msgs := make([]*schema.Message, 0, len(input.Messages)+1) if instruction != "" { msgs = append(msgs, schema.SystemMessage(instruction)) } msgs = append(msgs, input.Messages...) return msgs, nil } func buildBuiltinAgentMiddlewares(ctx context.Context, cfg *Config) ([]adk.ChatModelAgentMiddleware, error) { var ms []adk.ChatModelAgentMiddleware if !cfg.WithoutWriteTodos { t, err := newWriteTodos() if err != nil { return nil, err } ms = append(ms, t) } if cfg.Backend != nil || cfg.Shell != nil || cfg.StreamingShell != nil { fm, err := filesystem2.New(ctx, &filesystem2.MiddlewareConfig{ Backend: cfg.Backend, Shell: cfg.Shell, StreamingShell: cfg.StreamingShell, }) if err != nil { return nil, err } ms = append(ms, fm) } return ms, nil } type TODO struct { Content string `json:"content"` ActiveForm string `json:"activeForm"` Status string `json:"status" jsonschema:"enum=pending,enum=in_progress,enum=completed"` } type writeTodosArguments struct { Todos []TODO `json:"todos"` } func newWriteTodos() (adk.ChatModelAgentMiddleware, error) { toolDesc := internal.SelectPrompt(internal.I18nPrompts{ English: writeTodosToolDescription, Chinese: writeTodosToolDescriptionChinese, }) resultMsg := internal.SelectPrompt(internal.I18nPrompts{ English: "Updated todo list to %s", Chinese: "已更新待办列表为 %s", }) t, err := utils.InferTool("write_todos", toolDesc, func(ctx context.Context, input writeTodosArguments) (output string, err error) { adk.AddSessionValue(ctx, SessionKeyTodos, input.Todos) todos, err := sonic.MarshalString(input.Todos) if err != nil { return "", err } return fmt.Sprintf(resultMsg, todos), nil }) if err != nil { return nil, err } return buildAppendPromptTool("", t), nil } ================================================ FILE: adk/prebuilt/deep/deep_test.go ================================================ /* * Copyright 2025 
CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package deep

import (
	"context"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
	"go.uber.org/mock/gomock"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/adk/prebuilt/planexecute"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/tool"
	mockModel "github.com/cloudwego/eino/internal/mock/components/model"
	"github.com/cloudwego/eino/schema"
)

// TestGenModelInput checks that genModelInput prepends the instruction as a
// system message only when the instruction is non-empty.
func TestGenModelInput(t *testing.T) {
	ctx := context.Background()

	t.Run("WithInstruction", func(t *testing.T) {
		input := &adk.AgentInput{
			Messages: []*schema.Message{
				schema.UserMessage("hello"),
			},
		}

		msgs, err := genModelInput(ctx, "You are a helpful assistant", input)
		assert.NoError(t, err)
		assert.Len(t, msgs, 2)
		assert.Equal(t, schema.System, msgs[0].Role)
		assert.Equal(t, "You are a helpful assistant", msgs[0].Content)
		assert.Equal(t, schema.User, msgs[1].Role)
		assert.Equal(t, "hello", msgs[1].Content)
	})

	t.Run("WithoutInstruction", func(t *testing.T) {
		input := &adk.AgentInput{
			Messages: []*schema.Message{
				schema.UserMessage("hello"),
			},
		}

		msgs, err := genModelInput(ctx, "", input)
		assert.NoError(t, err)
		assert.Len(t, msgs, 1)
		assert.Equal(t, schema.User, msgs[0].Role)
		assert.Equal(t, "hello", msgs[0].Content)
	})
}

// TestWriteTodos invokes the write_todos tool directly and checks that its
// confirmation message echoes the todo list serialized back to JSON.
func TestWriteTodos(t *testing.T) {
	m, err := buildBuiltinAgentMiddlewares(context.Background(), &Config{WithoutWriteTodos: false})
	assert.NoError(t, err)

	// write_todos is the first built-in handler when WithoutWriteTodos is false.
	wt := m[0].(*appendPromptTool).t.(tool.InvokableTool)
	todos := `[{"content":"content1","activeForm":"","status":"pending"},{"content":"content2","activeForm":"","status":"pending"}]`
	args := fmt.Sprintf(`{"todos": %s}`, todos)
	result, err := wt.InvokableRun(context.Background(), args)
	assert.NoError(t, err)
	// The expected string relies on re-serialization reproducing the input
	// JSON byte-for-byte (same field order, no extra whitespace).
	assert.Equal(t, fmt.Sprintf("Updated todo list to %s", todos), result)
}

// TestDeepSubAgentSharesSessionValues checks that session values supplied to the
// parent run are visible inside a subagent invoked through the task tool.
func TestDeepSubAgentSharesSessionValues(t *testing.T) {
	ctx := context.Background()
	spy := &spySubAgent{}
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	cm := mockModel.NewMockToolCallingChatModel(ctrl)
	cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
	// Scripted model: the first call asks the task tool to run the spy subagent;
	// every later call ends the turn with "done".
	calls := 0
	cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
		DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) {
			calls++
			if calls == 1 {
				c := schema.ToolCall{ID: "id-1", Type: "function"}
				c.Function.Name = taskToolName
				c.Function.Arguments = fmt.Sprintf(`{"subagent_type":"%s","description":"from_parent"}`, spy.Name(ctx))
				return schema.AssistantMessage("", []schema.ToolCall{c}), nil
			}
			return schema.AssistantMessage("done", nil), nil
		}).AnyTimes()

	agent, err := New(ctx, &Config{
		Name:                   "deep",
		Description:            "deep agent",
		ChatModel:              cm,
		Instruction:            "you are deep agent",
		SubAgents:              []adk.Agent{spy},
		ToolsConfig:            adk.ToolsConfig{},
		MaxIteration:           2,
		WithoutWriteTodos:      true,
		WithoutGeneralSubAgent: true,
	})
	assert.NoError(t, err)

	r := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent})
	it := r.Run(ctx, []adk.Message{schema.UserMessage("hi")}, adk.WithSessionValues(map[string]any{"parent_key": "parent_val"}))
	// Drain the iterator so the run (and the subagent call) completes.
	for {
		if _, ok := it.Next(); !ok {
			break
		}
	}

	assert.Equal(t, "parent_val", spy.seenParentValue)
}

// TestDeepSubAgentFollowsStreamingMode checks that a subagent invoked through
// the task tool receives EnableStreaming=true when the runner streams.
func TestDeepSubAgentFollowsStreamingMode(t *testing.T) {
	ctx := context.Background()
	spy := &spyStreamingSubAgent{}
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	cm := mockModel.NewMockToolCallingChatModel(ctrl)
	cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()

	subName := spy.Name(ctx)
	// Two single-use Stream expectations: the first call returns the task tool
	// call targeting the spy, the second returns "done". The test relies on
	// gomock trying expectations in the order they are declared.
	cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(schema.StreamReaderFromArray([]*schema.Message{
			schema.AssistantMessage("", []schema.ToolCall{
				{
					ID:   "id-1",
					Type: "function",
					Function: schema.FunctionCall{
						Name:      taskToolName,
						Arguments: fmt.Sprintf(`{"subagent_type":"%s","description":"from_parent"}`, subName),
					},
				},
			}),
		}), nil).
		Times(1)
	cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(schema.StreamReaderFromArray([]*schema.Message{
			schema.AssistantMessage("done", nil),
		}), nil).
		Times(1)

	agent, err := New(ctx, &Config{
		Name:                   "deep",
		Description:            "deep agent",
		ChatModel:              cm,
		Instruction:            "you are deep agent",
		SubAgents:              []adk.Agent{spy},
		ToolsConfig:            adk.ToolsConfig{},
		MaxIteration:           2,
		WithoutWriteTodos:      true,
		WithoutGeneralSubAgent: true,
	})
	assert.NoError(t, err)

	r := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent, EnableStreaming: true})
	it := r.Run(ctx, []adk.Message{schema.UserMessage("hi")})
	for {
		if _, ok := it.Next(); !ok {
			break
		}
	}

	assert.True(t, spy.seenEnableStreaming)
}

// spySubAgent records the "parent_key" session value observed during Run and
// replies with a single "ok" assistant message.
type spySubAgent struct {
	seenParentValue any
}

func (s *spySubAgent) Name(context.Context) string { return "spy-subagent" }
func (s *spySubAgent) Description(context.Context) string { return "spy" }
func (s *spySubAgent) Run(ctx context.Context, _ *adk.AgentInput, _ ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
	s.seenParentValue, _ = adk.GetSessionValue(ctx, "parent_key")
	it, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]()
	gen.Send(adk.EventFromMessage(schema.AssistantMessage("ok", nil), nil, schema.Assistant, ""))
	gen.Close()
	return it
}

// spyStreamingSubAgent records input.EnableStreaming from the input it is run
// with and replies with a single "ok" assistant message.
type spyStreamingSubAgent struct {
	seenEnableStreaming bool
}

func (s *spyStreamingSubAgent) Name(context.Context) string { return "spy-streaming-subagent" }
func (s *spyStreamingSubAgent) Description(context.Context) string { return "spy" }
func (s *spyStreamingSubAgent) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
	if input != nil {
		s.seenEnableStreaming = input.EnableStreaming
	}
	it, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]()
	gen.Send(adk.EventFromMessage(schema.AssistantMessage("ok", nil), nil, schema.Assistant, ""))
	gen.Close()
	return it
}

// TestDeepAgentWithPlanExecuteSubAgent_InternalEventsEmitted checks that, with
// ToolsConfig.EmitInternalEvents set, events produced by the planner, executor
// and replanner inside a plan-execute subagent reach the outer runner.
func TestDeepAgentWithPlanExecuteSubAgent_InternalEventsEmitted(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	deepModel := mockModel.NewMockToolCallingChatModel(ctrl)
	plannerModel := mockModel.NewMockToolCallingChatModel(ctrl)
	executorModel := mockModel.NewMockToolCallingChatModel(ctrl)
	replannerModel := mockModel.NewMockToolCallingChatModel(ctrl)

	deepModel.EXPECT().WithTools(gomock.Any()).Return(deepModel, nil).AnyTimes()
	plannerModel.EXPECT().WithTools(gomock.Any()).Return(plannerModel, nil).AnyTimes()
	executorModel.EXPECT().WithTools(gomock.Any()).Return(executorModel, nil).AnyTimes()
	replannerModel.EXPECT().WithTools(gomock.Any()).Return(replannerModel, nil).AnyTimes()

	// Planner: streams a single "plan" tool call containing one step.
	plannerModel.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, input []*schema.Message, opts ...interface{}) (*schema.StreamReader[*schema.Message], error) {
			sr, sw := schema.Pipe[*schema.Message](1)
			go func() {
				defer sw.Close()
				planJSON := `{"steps":["step1"]}`
				msg := schema.AssistantMessage("", []schema.ToolCall{
					{
						ID:   "plan_call_1",
						Type: "function",
						Function: schema.FunctionCall{
							Name:      "plan",
							Arguments: planJSON,
						},
					},
				})
				sw.Send(msg, nil)
			}()
			return sr, nil
		},
	).Times(1)

	// Executor: completes the single step in one Generate call.
	executorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) {
			return schema.AssistantMessage("executed step1", nil), nil
		},
	).Times(1)

	// Replanner: streams a "respond" tool call, which ends the plan-execute loop.
	replannerModel.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, input []*schema.Message, opts ...interface{}) (*schema.StreamReader[*schema.Message], error) {
			sr, sw := schema.Pipe[*schema.Message](1)
			go func() {
				defer sw.Close()
				responseJSON := `{"response":"final response"}`
				msg := schema.AssistantMessage("", []schema.ToolCall{
					{
						ID:   "respond_call_1",
						Type: "function",
						Function: schema.FunctionCall{
							Name:      "respond",
							Arguments: responseJSON,
						},
					},
				})
				sw.Send(msg, nil)
			}()
			return sr, nil
		},
	).Times(1)

	planner, err := planexecute.NewPlanner(ctx, &planexecute.PlannerConfig{
		ToolCallingChatModel: plannerModel,
	})
	assert.NoError(t, err)

	executor, err := planexecute.NewExecutor(ctx, &planexecute.ExecutorConfig{
		Model: executorModel,
	})
	assert.NoError(t, err)

	replanner, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{
		ChatModel: replannerModel,
	})
	assert.NoError(t, err)

	planExecuteAgent, err := planexecute.New(ctx, &planexecute.Config{
		Planner:   planner,
		Executor:  executor,
		Replanner: replanner,
	})
	assert.NoError(t, err)

	// Give the plan-execute agent a fixed name so the deep agent's task tool
	// call can address it via subagent_type.
	namedPlanExecuteAgent := &namedPlanExecuteAgent{
		ResumableAgent: planExecuteAgent,
		name:           "plan_execute_subagent",
		description:    "a plan execute subagent",
	}

	// Deep agent model: the first call delegates to the plan-execute subagent;
	// later calls finish with "done".
	deepModelCalls := 0
	deepModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
		DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) {
			deepModelCalls++
			if deepModelCalls == 1 {
				c := schema.ToolCall{ID: "id-1", Type: "function"}
				c.Function.Name = taskToolName
				c.Function.Arguments = fmt.Sprintf(`{"subagent_type":"%s","description":"execute the plan"}`, namedPlanExecuteAgent.name)
				return schema.AssistantMessage("", []schema.ToolCall{c}), nil
			}
			return schema.AssistantMessage("done", nil), nil
		}).AnyTimes()

	deepAgent, err := New(ctx, &Config{
		Name:                   "deep",
		Description:            "deep agent",
		ChatModel:              deepModel,
		Instruction:            "you are deep agent",
		SubAgents:              []adk.Agent{namedPlanExecuteAgent},
		ToolsConfig:            adk.ToolsConfig{EmitInternalEvents: true},
		MaxIteration:           5,
		WithoutWriteTodos:      true,
		WithoutGeneralSubAgent: true,
	})
	assert.NoError(t, err)

	r := adk.NewRunner(ctx, adk.RunnerConfig{Agent: deepAgent})
	it := r.Run(ctx, []adk.Message{schema.UserMessage("hi")})

	var events []*adk.AgentEvent
	for {
		event, ok := it.Next()
		if !ok {
			break
		}
		events = append(events, event)
	}

	assert.Greater(t, len(events), 0, "should have at least one event")

	// Bucket events by the name of the agent that produced them.
	var deepAgentEvents []*adk.AgentEvent
	var plannerEvents []*adk.AgentEvent
	var executorEvents []*adk.AgentEvent
	var replannerEvents []*adk.AgentEvent
	var planExecuteEvents []*adk.AgentEvent

	for _, event := range events {
		switch event.AgentName {
		case "deep":
			deepAgentEvents = append(deepAgentEvents, event)
		case "planner":
			plannerEvents = append(plannerEvents, event)
		case "executor":
			executorEvents = append(executorEvents, event)
		case "replanner":
			replannerEvents = append(replannerEvents, event)
		case "plan_execute_replan", "execute_replan":
			planExecuteEvents = append(planExecuteEvents, event)
		}
	}

	assert.Greater(t, len(deepAgentEvents), 0, "should have events from deep agent")
	assert.Greater(t, len(plannerEvents), 0, "planner internal events should be emitted when EmitInternalEvents is true")
	assert.Greater(t, len(executorEvents), 0, "executor internal events should be emitted when EmitInternalEvents is true")
	assert.Greater(t, len(replannerEvents), 0, "replanner internal events should be emitted when EmitInternalEvents is true")

	t.Logf("Total events: %d", len(events))
	t.Logf("Deep agent events: %d", len(deepAgentEvents))
	t.Logf("Planner events: %d", len(plannerEvents))
	t.Logf("Executor events: %d", len(executorEvents))
	t.Logf("Replanner events: %d", len(replannerEvents))
	t.Logf("PlanExecute events: %d", len(planExecuteEvents))
}

// namedPlanExecuteAgent wraps a ResumableAgent and overrides only its Name and
// Description; all other methods are promoted from the embedded agent.
type namedPlanExecuteAgent struct {
	adk.ResumableAgent
	name        string
	description string
}

func (n *namedPlanExecuteAgent) Name(_ context.Context) string {
	return n.name
}

func (n *namedPlanExecuteAgent) Description(_ context.Context) string {
	return n.description
}

// TestDeepAgentOutputKey checks that Config.OutputKey stores the final message
// content in the session (for both Generate and streaming runs), and that
// nothing is stored when OutputKey is empty.
func TestDeepAgentOutputKey(t *testing.T) {
	t.Run("OutputKeyStoresInSession", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		cm := mockModel.NewMockToolCallingChatModel(ctrl)
		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("Hello from DeepAgent", nil), nil).
			Times(1)

		agent, err := New(ctx, &Config{
			Name:                   "deep",
			Description:            "deep agent",
			ChatModel:              cm,
			Instruction:            "you are deep agent",
			MaxIteration:           2,
			WithoutWriteTodos:      true,
			WithoutGeneralSubAgent: true,
			OutputKey:              "deep_output",
		})
		assert.NoError(t, err)

		var capturedSessionValues map[string]any
		wrappedAgent := &sessionCaptureAgent{
			Agent: agent,
			captureSession: func(values map[string]any) {
				capturedSessionValues = values
			},
		}

		r := adk.NewRunner(ctx, adk.RunnerConfig{Agent: wrappedAgent})
		it := r.Run(ctx, []adk.Message{schema.UserMessage("hi")})
		for {
			if _, ok := it.Next(); !ok {
				break
			}
		}

		assert.Contains(t, capturedSessionValues, "deep_output")
		assert.Equal(t, "Hello from DeepAgent", capturedSessionValues["deep_output"])
	})

	t.Run("OutputKeyWithStreamingStoresInSession", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		cm := mockModel.NewMockToolCallingChatModel(ctrl)
		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		// The streamed chunks are expected to be concatenated into the stored value.
		cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.StreamReaderFromArray([]*schema.Message{
				schema.AssistantMessage("Hello", nil),
				schema.AssistantMessage(" from", nil),
				schema.AssistantMessage(" DeepAgent", nil),
			}), nil).
			Times(1)

		agent, err := New(ctx, &Config{
			Name:                   "deep",
			Description:            "deep agent",
			ChatModel:              cm,
			Instruction:            "you are deep agent",
			MaxIteration:           2,
			WithoutWriteTodos:      true,
			WithoutGeneralSubAgent: true,
			OutputKey:              "deep_output",
		})
		assert.NoError(t, err)

		var capturedSessionValues map[string]any
		wrappedAgent := &sessionCaptureAgent{
			Agent: agent,
			captureSession: func(values map[string]any) {
				capturedSessionValues = values
			},
		}

		r := adk.NewRunner(ctx, adk.RunnerConfig{Agent: wrappedAgent, EnableStreaming: true})
		it := r.Run(ctx, []adk.Message{schema.UserMessage("hi")})
		for {
			if _, ok := it.Next(); !ok {
				break
			}
		}

		assert.Contains(t, capturedSessionValues, "deep_output")
		assert.Equal(t, "Hello from DeepAgent", capturedSessionValues["deep_output"])
	})

	t.Run("OutputKeyNotSetWhenEmpty", func(t *testing.T) {
		ctx := context.Background()
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		cm := mockModel.NewMockToolCallingChatModel(ctrl)
		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
			Return(schema.AssistantMessage("Hello from DeepAgent", nil), nil).
			Times(1)

		agent, err := New(ctx, &Config{
			Name:                   "deep",
			Description:            "deep agent",
			ChatModel:              cm,
			Instruction:            "you are deep agent",
			MaxIteration:           2,
			WithoutWriteTodos:      true,
			WithoutGeneralSubAgent: true,
		})
		assert.NoError(t, err)

		var capturedSessionValues map[string]any
		wrappedAgent := &sessionCaptureAgent{
			Agent: agent,
			captureSession: func(values map[string]any) {
				capturedSessionValues = values
			},
		}

		r := adk.NewRunner(ctx, adk.RunnerConfig{Agent: wrappedAgent})
		it := r.Run(ctx, []adk.Message{schema.UserMessage("hi")})
		for {
			if _, ok := it.Next(); !ok {
				break
			}
		}

		assert.NotContains(t, capturedSessionValues, "deep_output")
	})
}

// sessionCaptureAgent forwards every event of the wrapped agent and, once the
// wrapped agent's iterator is exhausted, hands the session values visible on
// the run context to captureSession.
type sessionCaptureAgent struct {
	adk.Agent
	captureSession func(map[string]any)
}

func (s *sessionCaptureAgent) Run(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
	innerIt := s.Agent.Run(ctx, input, opts...)
	it, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]()
	go func() {
		// gen.Close runs after captureSession (deferred), so a caller that has
		// drained the returned iterator observes the captured values.
		defer gen.Close()
		for {
			event, ok := innerIt.Next()
			if !ok {
				break
			}
			gen.Send(event)
		}
		s.captureSession(adk.GetSessionValues(ctx))
	}()
	return it
}

================================================ FILE: adk/prebuilt/deep/prompt.go ================================================
/*
 * Copyright (c) 2025 Harrison Chase
 * Copyright (c) 2025 CloudWeGo Authors
 * SPDX-License-Identifier: MIT
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package deep // This file contains prompt templates and tool descriptions adapted from the DeepAgents project and ClaudeCode. // Original source: https://github.com/langchain-ai/deepagents https://claude.com/product/claude-code // // These prompts are used under the terms of the original project's open source license. // When using this code in your own open source project, ensure compliance with the original license requirements. const ( taskPrompt = ` # 'task' (subagent spawner) You have access to a 'task' tool to launch short-lived subagents that handle isolated tasks. These agents are ephemeral — they live only for the duration of the task and return a single result. You should proactively use the 'task' tool with specialized agents when the task at hand matches the agent's description. When to use the task tool: - When a task is complex and multi-step, and can be fully delegated in isolation - When a task is independent of other tasks and can run in parallel - When a task requires focused reasoning or heavy token/context usage that would bloat the orchestrator thread - When sandboxing improves reliability (e.g. code execution, structured searches, data formatting) - When you only care about the output of the subagent, and not the intermediate steps (ex. performing a lot of research and then returned a synthesized report, performing a series of computations or lookups to achieve a concise, relevant answer.) Subagent lifecycle: 1. **Spawn** → Provide clear role, instructions, and expected output 2. **Run** → The subagent completes the task autonomously 3. **Return** → The subagent provides a single structured result 4. 
**Reconcile** → Incorporate or synthesize the result into the main thread When NOT to use the task tool: - If you need to see the intermediate reasoning or steps after the subagent has completed (the task tool hides them) - If the task is trivial (a few tool calls or simple lookup) - If delegating does not reduce token usage, complexity, or context switching - If splitting would add latency without benefit ## Important Task Tool Usage Notes to Remember - Whenever possible, parallelize the work that you do. This is true for both tool_calls, and for tasks. Whenever you have independent steps to complete - make tool_calls, or kick off tasks (subagents) in parallel to accomplish them faster. This saves time for the user, which is incredibly important. - Remember to use the 'task' tool to silo independent tasks within a multi-part objective. - You should use the 'task' tool whenever you have a complex task that will take multiple steps, and is independent from other tasks that the agent needs to complete. These agents are highly competent and efficient. ` baseAgentInstruction = ` You are a helpful assistant. Use the instructions below and the tools available to you to assist the user. IMPORTANT: Assist with authorized security testing, defensive security, CTF challenges, and educational contexts. Refuse requests for destructive techniques, DoS attacks, mass targeting, supply chain compromise, or detection evasion for malicious purposes. Dual-use security tools (C2 frameworks, credential testing, exploit development) require clear authorization context: pentesting engagements, CTF competitions, security research, or defensive use cases. IMPORTANT: You must NEVER generate or guess URLs for the user unless you are confident that the URLs are for helping the user with programming. You may use URLs provided by the user in their messages or local files. # Tone and style - Only use emojis if the user explicitly requests it. Avoid using emojis in all communication unless asked. 
- Your output will be displayed on a command line interface. Your responses should be short and concise. You can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification. - Output text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session. - NEVER create files unless they're absolutely necessary for achieving your goal. ALWAYS prefer editing an existing file to creating a new one. This includes markdown files. - Do not use a colon before tool calls. Your tool calls may not be shown directly in the output, so text like "Let me read the file:" followed by a read tool call should just be "Let me read the file." with a period. # Professional objectivity Prioritize technical accuracy and truthfulness over validating the user's beliefs. Focus on facts and problem-solving, providing direct, objective technical info without any unnecessary superlatives, praise, or emotional validation. It is best for the user if agent honestly applies the same rigorous standards to all ideas and disagrees when necessary, even if it may not be what the user wants to hear. Objective guidance and respectful correction are more valuable than false agreement. Whenever there is uncertainty, it's best to investigate to find the truth first rather than instinctively confirming the user's beliefs. Avoid using over-the-top validation or excessive praise when responding to users such as "You're absolutely right" or similar phrases. # Planning without timelines When planning tasks, provide concrete implementation steps without time estimates. Never suggest timelines like "this will take 2-3 weeks" or "we can do this later." Focus on what needs to be done, not when. Break work into actionable steps and let users decide scheduling. 
# Doing tasks The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended: - NEVER propose changes to code you haven't read. If a user asks about or wants you to modify a file, read it first. Understand existing code before suggesting modifications. - Be careful not to introduce security vulnerabilities such as command injection, XSS, SQL injection, and other OWASP top 10 vulnerabilities. If you notice that you wrote insecure code, immediately fix it. - Avoid over-engineering. Only make changes that are directly requested or clearly necessary. Keep solutions simple and focused. - Don't add features, refactor code, or make "improvements" beyond what was asked. A bug fix doesn't need surrounding code cleaned up. A simple feature doesn't need extra configurability. Don't add docstrings, comments, or type annotations to code you didn't change. Only add comments where the logic isn't self-evident. - Don't add error handling, fallbacks, or validation for scenarios that can't happen. Trust internal code and framework guarantees. Only validate at system boundaries (user input, external APIs). Don't use feature flags or backwards-compatibility shims when you can just change the code. - Don't create helpers, utilities, or abstractions for one-time operations. Don't design for hypothetical future requirements. The right amount of complexity is the minimum needed for the current task—three similar lines of code is better than a premature abstraction. - Avoid backwards-compatibility hacks like renaming unused '_vars', re-exporting types, adding '// removed' comments for removed code, etc. If something is unused, delete it completely. - Tool results and user messages may include tags. tags contain useful information and reminders. 
They are automatically added by the system, and bear no direct relation to the specific tool results or user messages in which they appear. - The conversation has unlimited context through automatic summarization. IMPORTANT: Complete tasks fully. Do not stop mid-task or leave work incomplete. Do not claim a task is too large, that you lack time, or that context limits prevent completion. You have unlimited context through summarization. Continue working until the task is done or the user stops you. # Tool usage policy - You can call multiple tools in a single response. If you intend to call multiple tools and there are no dependencies between them, make all independent tool calls in parallel. Maximize use of parallel tool calls where possible to increase efficiency. However, if some tool calls depend on previous calls to inform dependent values, do NOT call these tools in parallel and instead call them sequentially. For instance, if one operation must complete before another starts, run these operations sequentially instead. Never use placeholders or guess missing parameters in tool calls. - If the user specifies that they want you to run tools "in parallel", you MUST send a single message with multiple tool use content blocks. For example, if you need to launch multiple agents in parallel, send a single message with multiple Task tool calls. - Use specialized tools instead of bash commands when possible, as this provides a better user experience. For file operations, use dedicated tools: Read for reading files instead of cat/head/tail, Edit for editing instead of sed/awk, and Write for creating files instead of cat with heredoc or echo redirection. Reserve bash tools exclusively for actual system commands and terminal operations that require shell execution. NEVER use bash echo or other command-line tools to communicate thoughts, explanations, or instructions to the user. Output all communication directly in your response text instead. 
IMPORTANT: Assist with authorized security testing, defensive security, CTF challenges, and educational contexts. Refuse requests for destructive techniques, DoS attacks, mass targeting, supply chain compromise, or detection evasion for malicious purposes. Dual-use security tools (C2 frameworks, credential testing, exploit development) require clear authorization context: pentesting engagements, CTF competitions, security research, or defensive use cases. # Code References When referencing specific functions or pieces of code include the pattern 'file_path:line_number' to allow the user to easily navigate to the source code location. user: Where are errors from the client handled? assistant: Clients are marked as failed in the 'connectToServer' function in src/services/process.ts:712. ` generalAgentDescription = `general-purpose agent for researching complex questions, searching for code, and executing multi-step tasks. When you are searching for a keyword or file and are not confident that you will find the right match in the first few tries use this agent to perform the search for you. (Tools: *)` taskToolDescription = `Launch a new agent to handle complex, multi-step tasks autonomously. The Task tool launches specialized agents (subprocesses) that autonomously handle complex tasks. Each agent type has specific capabilities and tools available to it. Available agent types and the tools they have access to: {other_agents} When using the Task tool, you must specify a subagent_type parameter to select which agent type to use. 
When NOT to use the Task tool: - If you want to read a specific file path, use the Read or Glob tool instead of the Task tool, to find the match more quickly - If you are searching for a specific class definition like "class Foo", use the Glob tool instead, to find the match more quickly - If you are searching for code within a specific file or set of 2-3 files, use the Read tool instead of the Task tool, to find the match more quickly - Other tasks that are not related to the agent descriptions above Usage notes: - Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses - When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result. - Provide clear, detailed prompts so the agent can work autonomously and return exactly the information you need. - The agent's outputs should generally be trusted - Clearly tell the agent whether you expect it to write code or just to do research (search, file reads, web fetches, etc.), since it is not aware of the user's intent - If the agent description mentions that it should be used proactively, then you should try your best to use it without the user having to ask for it first. Use your judgement. - If the user specifies that they want you to run agents "in parallel", you MUST send a single message with multiple Task tool use content blocks. For example, if you need to launch both a code-reviewer agent and a test-runner agent in parallel, send a single message with both tool calls. 
Example usage: "code-reviewer": use this agent after you are done writing a significant piece of code "greeting-responder": use this agent when to respond to user greetings with a friendly joke user: "Please write a function that checks if a number is prime" assistant: Sure let me write a function that checks if a number is prime assistant: First let me use the Write tool to write a function that checks if a number is prime assistant: I'm going to use the Write tool to write the following code: function isPrime(n) {{ if (n <= 1) return false for (let i = 2; i * i <= n; i++) {{ if (n %!i(MISSING) === 0) return false }} return true }} Since a significant piece of code was written and the task was completed, now use the code-reviewer agent to review the code assistant: Now let me use the code-reviewer agent to review the code assistant: Uses the Task tool to launch the code-reviewer agent user: "Hello" Since the user is greeting, use the greeting-responder agent to respond with a friendly joke assistant: "I'm going to use the Task tool to launch the greeting-responder agent" ` writeTodosToolDescription = `Use this tool to create and manage a structured task list for your current coding session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user. It also helps the user understand the progress of the task and overall progress of their requests. ## When to Use This Tool Use this tool proactively in these scenarios: 1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions 2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations 3. User explicitly requests todo list - When the user directly asks you to use the todo list 4. User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated) 5. After receiving new instructions - Immediately capture user requirements as todos 6. 
When you start working on a task - Mark it as in_progress BEFORE beginning work. Ideally you should only have one todo as in_progress at a time 7. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation ## When NOT to Use This Tool Skip using this tool when: 1. There is only a single, straightforward task 2. The task is trivial and tracking it provides no organizational benefit 3. The task can be completed in less than 3 trivial steps 4. The task is purely conversational or informational NOTE that you should not use this tool if there is only one trivial task to do. In this case you are better off just doing the task directly. ## Examples of When to Use the Todo List User: I want to add a dark mode toggle to the application settings. Make sure you run the tests and build when you're done! Assistant: I'll help add a dark mode toggle to your application settings. Let me create a todo list to track this implementation. *Creates todo list with the following items:* 1. Creating dark mode toggle component in Settings page 2. Adding dark mode state management (context/store) 3. Implementing CSS-in-JS styles for dark theme 4. Updating existing components to support theme switching 5. Running tests and build process, addressing any failures or errors that occur *Begins working on the first task* The assistant used the todo list because: 1. Adding dark mode is a multi-step feature requiring UI, state management, and styling changes 2. The user explicitly requested tests and build be run afterward 3. The assistant inferred that tests and build need to pass by adding "Ensure tests and build succeed" as the final task User: Help me rename the function getCwd to getCurrentWorkingDirectory across my project Assistant: Let me first search through your codebase to find all occurrences of 'getCwd'. 
*Uses grep or search tools to locate all instances of getCwd in the codebase* Assistant: I've found 15 instances of 'getCwd' across 8 different files. Let me create a todo list to track these changes. *Creates todo list with specific items for each file that needs updating* The assistant used the todo list because: 1. First, the assistant searched to understand the scope of the task 2. Upon finding multiple occurrences across different files, it determined this was a complex task with multiple steps 3. The todo list helps ensure every instance is tracked and updated systematically 4. This approach prevents missing any occurrences and maintains code consistency User: I need to implement these features for my e-commerce site: user registration, product catalog, shopping cart, and checkout flow. Assistant: I'll help implement these features. First, let's add all the features to the todo list. *Creates a todo list breaking down each feature into specific tasks based on the project architecture* Assistant: Let's start with user registration. This will involve creating database models, API endpoints, and frontend forms that integrate with your existing authentication system. The assistant used the todo list because: 1. The user provided multiple complex features to implement in a comma separated list 2. The todo list helps organize these large features into manageable tasks 3. This approach allows for tracking progress across the entire implementation User: Can you help optimize my React application? It's rendering slowly and has performance issues. Assistant: I'll help optimize your React application. First, let me examine your codebase to identify potential performance bottlenecks. *Reviews component structure, render patterns, state management, and data fetching* Assistant: After analyzing your codebase, I've identified several performance issues. Let me create a todo list to track our optimization efforts. 
*Creates todo list with items like: 1) Implementing memoization for expensive calculations in ProductList, 2) Adding virtualization for long lists in Dashboard, 3) Optimizing image loading in Gallery component, 4) Fixing state update loops in ShoppingCart, 5) Reviewing bundle size and implementing code splitting* Let's start by implementing memoization for the expensive calculations in your ProductList component. The assistant used the todo list because: 1. First, the assistant examined the codebase to identify specific performance issues 2. Based on this analysis, it identified multiple optimization opportunities 3. Performance optimization is a non-trivial task requiring multiple steps 4. The todo list helps methodically track improvements across different components 5. This systematic approach ensures all performance bottlenecks are addressed ## Examples of When NOT to Use the Todo List User: How do I print 'Hello World' in Python? Assistant: In Python, you can print "Hello World" with this simple code: python print("Hello World") This will output the text "Hello World" to the console when executed. The assistant did not use the todo list because this is a single, trivial task that can be completed in one step. There's no need to track multiple tasks or steps for such a straightforward request. User: What does the git status command do? Assistant: The git status command shows the current state of your working directory and staging area. It displays which changes have been staged, which haven't, and which files aren't being tracked by Git. This is a fundamental command for understanding the current state of your repository before making commits. The assistant did not use the todo list because this is an informational request with no actual coding task to complete. The user is simply asking for an explanation, not for the assistant to perform multiple steps or tasks. User: Can you add a comment to the calculateTotal function to explain what it does? 
Assistant: Sure, let me add a comment to the calculateTotal function to explain what it does. * Uses the Edit tool to add a comment to the calculateTotal function * The assistant did not use the todo list because this is a single, straightforward task confined to one location in the code. Adding a comment doesn't require tracking multiple steps or systematic organization. User: Run npm install for me and tell me what happens. Assistant: I'll run the npm install command for you. *Executes: npm install* The command completed successfully. Here's the output: [Output of npm install command] All dependencies have been installed according to your package.json file. The assistant did not use the todo list because this is a single command execution with immediate results. There are no multiple steps to track or organize, making the todo list unnecessary for this straightforward task. ## Task States and Management 1. **Task States**: Use these states to track progress: - pending: Task not yet started - in_progress: Currently working on (limit to ONE task at a time) - completed: Task finished successfully **IMPORTANT**: Task descriptions must have two forms: - content: The imperative form describing what needs to be done (e.g., "Run tests", "Build the project") - activeForm: The present continuous form shown during execution (e.g., "Running tests", "Building the project") 2. **Task Management**: - Update task status in real-time as you work - Mark tasks complete IMMEDIATELY after finishing (don't batch completions) - Exactly ONE task must be in_progress at any time (not less, not more) - Complete current tasks before starting new ones - Remove tasks that are no longer relevant from the list entirely 3. 
**Task Completion Requirements**: - ONLY mark a task as completed when you have FULLY accomplished it - If you encounter errors, blockers, or cannot finish, keep the task as in_progress - When blocked, create a new task describing what needs to be resolved - Never mark a task as completed if: - Tests are failing - Implementation is partial - You encountered unresolved errors - You couldn't find necessary files or dependencies 4. **Task Breakdown**: - Create specific, actionable items - Break complex tasks into smaller, manageable steps - Use clear, descriptive task names - Always provide both forms: - content: "Fix authentication bug" - activeForm: "Fixing authentication bug" When in doubt, use this tool. Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully. ` taskPromptChinese = ` # 'task'(子代理生成器) 你可以使用 'task' 工具来启动处理独立任务的短期子代理。这些代理是临时的——它们只在任务持续期间存在,并返回单个结果。 当手头的任务与代理的描述匹配时,你应该主动使用带有专门代理的 'task' 工具。 何时使用 task 工具: - 当任务复杂且包含多个步骤,并且可以完全独立委托时 - 当任务独立于其他任务并且可以并行运行时 - 当任务需要集中推理或大量 token/上下文使用,这会使编排器线程膨胀时 - 当沙箱化提高可靠性时(例如代码执行、结构化搜索、数据格式化) - 当你只关心子代理的输出,而不关心中间步骤时(例如执行大量研究然后返回综合报告,执行一系列计算或查找以获得简洁、相关的答案) 子代理生命周期: 1. **生成** → 提供清晰的角色、指令和预期输出 2. **运行** → 子代理自主完成任务 3. **返回** → 子代理提供单个结构化结果 4. 
**协调** → 将结果合并或综合到主线程中 何时不使用 task 工具: - 如果你需要在子代理完成后查看中间推理或步骤(task 工具会隐藏它们) - 如果任务很简单(几个工具调用或简单查找) - 如果委托不会减少 token 使用、复杂性或上下文切换 - 如果拆分会增加延迟而没有好处 ## 重要的 Task 工具使用注意事项 - 尽可能并行化你所做的工作。这对于 tool_calls 和 tasks 都适用。每当你有独立的步骤要完成时——进行 tool_calls,或并行启动任务(子代理)以更快地完成它们。这为用户节省了时间,这非常重要。 - 记住使用 'task' 工具在多部分目标中隔离独立任务。 - 每当你有一个需要多个步骤的复杂任务,并且独立于代理需要完成的其他任务时,你应该使用 'task' 工具。这些代理非常有能力且高效。 ` baseAgentInstructionChinese = ` 你是一个有帮助的助手。使用以下说明和可用的工具来协助用户。 重要:协助授权的安全测试、防御性安全、CTF 挑战和教育环境。拒绝破坏性技术、DoS 攻击、大规模目标、供应链攻击或用于恶意目的的检测规避请求。双用途安全工具(C2 框架、凭证测试、漏洞开发)需要明确的授权上下文:渗透测试参与、CTF 竞赛、安全研究或防御用例。 重要:除非你确信 URL 是为了帮助用户编程,否则你绝不能为用户生成或猜测 URL。你可以使用用户在其消息或本地文件中提供的 URL。 # 语气和风格 - 仅在用户明确要求时使用表情符号。除非被要求,否则避免在所有通信中使用表情符号。 - 你的输出将显示在命令行界面上。你的回复应该简短而简洁。你可以使用 Github 风格的 markdown 进行格式化,并将使用 CommonMark 规范以等宽字体呈现。 - 输出文本与用户交流;你在工具使用之外输出的所有文本都会显示给用户。仅使用工具来完成任务。永远不要使用 Bash 或代码注释等工具作为在会话期间与用户交流的手段。 - 除非绝对必要以实现你的目标,否则永远不要创建文件。始终优先编辑现有文件而不是创建新文件。这包括 markdown 文件。 - 不要在工具调用前使用冒号。你的工具调用可能不会直接显示在输出中,因此像"让我读取文件:"后跟读取工具调用的文本应该只是"让我读取文件。"加句号。 # 专业客观性 优先考虑技术准确性和真实性,而不是验证用户的信念。专注于事实和解决问题,提供直接、客观的技术信息,不要有任何不必要的夸大、赞美或情感验证。如果代理诚实地对所有想法应用相同的严格标准,并在必要时提出异议,即使这可能不是用户想听到的,对用户来说是最好的。客观的指导和尊重的纠正比虚假的同意更有价值。每当存在不确定性时,最好先调查以找到真相,而不是本能地确认用户的信念。避免在回复用户时使用过度的验证或过度的赞美,如"你完全正确"或类似的短语。 # 不带时间线的规划 在规划任务时,提供具体的实施步骤而不带时间估计。永远不要建议像"这将需要 2-3 周"或"我们可以稍后做这个"这样的时间线。专注于需要做什么,而不是什么时候。将工作分解为可操作的步骤,让用户决定日程安排。 # 执行任务 用户主要会要求你执行软件工程任务。这包括解决 bug、添加新功能、重构代码、解释代码等。对于这些任务,建议以下步骤: - 永远不要对你没有阅读过的代码提出更改建议。如果用户询问或希望你修改文件,请先阅读它。在建议修改之前理解现有代码。 - 注意不要引入安全漏洞,如命令注入、XSS、SQL 注入和其他 OWASP 前 10 名漏洞。如果你注意到你编写了不安全的代码,请立即修复它。 - 避免过度工程。只进行直接要求或明显必要的更改。保持解决方案简单和专注。 - 不要添加功能、重构代码或进行超出要求的"改进"。bug 修复不需要清理周围的代码。简单的功能不需要额外的可配置性。不要向你没有更改的代码添加文档字符串、注释或类型注解。只在逻辑不明显的地方添加注释。 - 不要为不可能发生的场景添加错误处理、回退或验证。信任内部代码和框架保证。只在系统边界(用户输入、外部 API)进行验证。当你可以直接更改代码时,不要使用功能标志或向后兼容性垫片。 - 不要为一次性操作创建辅助函数、工具或抽象。不要为假设的未来需求进行设计。正确的复杂度是当前任务所需的最小值——三行类似的代码比过早的抽象更好。 - 避免向后兼容性技巧,如重命名未使用的 '_vars'、重新导出类型、为删除的代码添加 '// removed' 注释等。如果某些东西未使用,请完全删除它。 - 工具结果和用户消息可能包含 标签。 标签包含有用的信息和提醒。它们由系统自动添加,与它们出现的特定工具结果或用户消息没有直接关系。 - 
对话通过自动摘要具有无限上下文。 重要:完全完成任务。不要在任务中途停止或留下未完成的工作。不要声称任务太大、你缺乏时间或上下文限制阻止完成。你通过摘要具有无限上下文。继续工作直到任务完成或用户停止你。 # 工具使用策略 - 你可以在单个响应中调用多个工具。如果你打算调用多个工具并且它们之间没有依赖关系,请并行进行所有独立的工具调用。尽可能最大化使用并行工具调用以提高效率。但是,如果某些工具调用依赖于先前的调用来通知依赖值,请不要并行调用这些工具,而是按顺序调用它们。例如,如果一个操作必须在另一个操作开始之前完成,请按顺序运行这些操作。永远不要在工具调用中使用占位符或猜测缺失的参数。 - 如果用户指定他们希望你"并行"运行工具,你必须发送一条包含多个工具使用内容块的消息。例如,如果你需要并行启动多个代理,请发送一条包含多个 Task 工具调用的消息。 - 尽可能使用专门的工具而不是 bash 命令,因为这提供了更好的用户体验。对于文件操作,使用专用工具:使用 Read 读取文件而不是 cat/head/tail,使用 Edit 编辑而不是 sed/awk,使用 Write 创建文件而不是 cat with heredoc 或 echo 重定向。仅将 bash 工具保留用于需要 shell 执行的实际系统命令和终端操作。永远不要使用 bash echo 或其他命令行工具向用户传达想法、解释或指令。而是直接在你的响应文本中输出所有通信。 重要:协助授权的安全测试、防御性安全、CTF 挑战和教育环境。拒绝破坏性技术、DoS 攻击、大规模目标、供应链攻击或用于恶意目的的检测规避请求。双用途安全工具(C2 框架、凭证测试、漏洞开发)需要明确的授权上下文:渗透测试参与、CTF 竞赛、安全研究或防御用例。 # 代码引用 在引用特定函数或代码片段时,包含 'file_path:line_number' 模式,以允许用户轻松导航到源代码位置。 user: 客户端的错误在哪里处理? assistant: 客户端在 src/services/process.ts:712 的 'connectToServer' 函数中被标记为失败。 ` generalAgentDescriptionChinese = `通用代理,用于研究复杂问题、搜索代码和执行多步骤任务。当你搜索关键字或文件并且不确定在前几次尝试中能否找到正确的匹配时,使用此代理为你执行搜索。(工具:*)` taskToolDescriptionChinese = `启动新代理以自主处理复杂的多步骤任务。 Task 工具启动专门的代理(子进程),自主处理复杂任务。每种代理类型都有特定的功能和可用的工具。 可用的代理类型及其可访问的工具: {other_agents} 使用 Task 工具时,你必须指定 subagent_type 参数来选择要使用的代理类型。 何时不使用 Task 工具: - 如果你想读取特定的文件路径,请使用 Read 或 Glob 工具而不是 Task 工具,以更快地找到匹配项 - 如果你正在搜索特定的类定义,如"class Foo",请使用 Glob 工具,以更快地找到匹配项 - 如果你正在特定文件或 2-3 个文件集中搜索代码,请使用 Read 工具而不是 Task 工具,以更快地找到匹配项 - 与上述代理描述无关的其他任务 使用说明: - 尽可能同时启动多个代理,以最大化性能;为此,使用包含多个工具使用的单条消息 - 当代理完成时,它将向你返回一条消息。代理返回的结果对用户不可见。要向用户显示结果,你应该向用户发送一条文本消息,简要总结结果。 - 提供清晰、详细的提示,以便代理可以自主工作并返回你需要的确切信息。 - 代理的输出通常应该被信任 - 明确告诉代理你期望它编写代码还是只是进行研究(搜索、文件读取、网络获取等),因为它不知道用户的意图 - 如果代理描述提到应该主动使用它,那么你应该尽力使用它即使用户没有这样要求。使用你的判断。 - 如果用户指定他们希望你"并行"运行代理,你必须发送一条包含多个 Task 工具使用内容块的消息。例如,如果你需要并行启动代码审查代理和测试运行代理,请发送一条包含两个工具调用的消息。 使用示例: "code-reviewer": 在你完成编写重要代码后使用此代理 "greeting-responder": 当要用友好的笑话回应用户问候时使用此代理 user: "请编写一个检查数字是否为质数的函数" assistant: 好的,让我编写一个检查数字是否为质数的函数 assistant: 首先让我使用 Write 工具编写一个检查数字是否为质数的函数 assistant: 我将使用 Write 工具编写以下代码: function 
isPrime(n) {{ if (n <= 1) return false for (let i = 2; i * i <= n; i++) {{ if (n %!i(MISSING) === 0) return false }} return true }} 由于编写了重要的代码并且任务已完成,现在使用 code-reviewer 代理来审查代码 assistant: 现在让我使用 code-reviewer 代理来审查代码 assistant: 使用 Task 工具启动 code-reviewer 代理 user: "你好" 由于用户在问候,使用 greeting-responder 代理用友好的笑话回应 assistant: "我将使用 Task 工具启动 greeting-responder 代理" ` writeTodosToolDescriptionChinese = `使用此工具为你当前的编码会话创建和管理结构化的任务列表。这有助于你跟踪进度、组织复杂任务,并向用户展示你的彻底性。 它还帮助用户了解任务的进度和他们请求的整体进度。 ## 何时使用此工具 在以下场景中主动使用此工具: 1. 复杂的多步骤任务 - 当任务需要 3 个或更多不同的步骤或操作时 2. 非平凡和复杂的任务 - 需要仔细规划或多个操作的任务 3. 用户明确要求待办事项列表 - 当用户直接要求你使用待办事项列表时 4. 用户提供多个任务 - 当用户提供要完成的事项列表(编号或逗号分隔)时 5. 收到新指令后 - 立即将用户需求捕获为待办事项 6. 当你开始处理任务时 - 在开始工作之前将其标记为进行中。理想情况下,你一次只应该有一个待办事项处于进行中状态 7. 完成任务后 - 将其标记为已完成,并添加在实施过程中发现的任何新的后续任务 ## 何时不使用此工具 在以下情况下跳过使用此工具: 1. 只有一个简单直接的任务 2. 任务很简单,跟踪它没有组织上的好处 3. 任务可以在少于 3 个简单步骤内完成 4. 任务纯粹是对话性的或信息性的 注意,如果只有一个简单的任务要做,你不应该使用此工具。在这种情况下,你最好直接完成任务。 ## 何时使用待办事项列表的示例 User: 我想在应用程序设置中添加一个深色模式切换。完成后确保运行测试和构建! Assistant: 我将帮助在你的应用程序设置中添加深色模式切换。让我创建一个待办事项列表来跟踪此实现。 *创建包含以下项目的待办事项列表:* 1. 在设置页面创建深色模式切换组件 2. 添加深色模式状态管理(context/store) 3. 实现深色主题的 CSS-in-JS 样式 4. 更新现有组件以支持主题切换 5. 运行测试和构建过程,解决发生的任何失败或错误 *开始处理第一个任务* 助手使用待办事项列表是因为: 1. 添加深色模式是一个需要 UI、状态管理和样式更改的多步骤功能 2. 用户明确要求之后运行测试和构建 3. 助手推断测试和构建需要通过,将"确保测试和构建成功"添加为最终任务 User: 帮我在整个项目中将函数 getCwd 重命名为 getCurrentWorkingDirectory Assistant: 让我首先搜索你的代码库以找到所有 'getCwd' 的出现。 *使用 grep 或搜索工具在代码库中定位所有 getCwd 实例* Assistant: 我在 8 个不同的文件中找到了 15 个 'getCwd' 实例。让我创建一个待办事项列表来跟踪这些更改。 *创建包含每个需要更新的文件的特定项目的待办事项列表* 助手使用待办事项列表是因为: 1. 首先,助手搜索以了解任务的范围 2. 在发现跨多个文件的多个出现后,它确定这是一个包含多个步骤的复杂任务 3. 待办事项列表有助于确保每个实例都被系统地跟踪和更新 4. 这种方法可以防止遗漏任何出现并保持代码一致性 User: 我需要为我的电子商务网站实现这些功能:用户注册、产品目录、购物车和结账流程。 Assistant: 我将帮助实现这些功能。首先,让我们将所有功能添加到待办事项列表中。 *创建一个待办事项列表,根据项目架构将每个功能分解为特定任务* Assistant: 让我们从用户注册开始。这将涉及创建数据库模型、API 端点和与你现有身份验证系统集成的前端表单。 助手使用待办事项列表是因为: 1. 用户在逗号分隔的列表中提供了多个复杂功能要实现 2. 待办事项列表有助于将这些大功能组织成可管理的任务 3. 
这种方法允许跟踪整个实现过程中的进度 User: 你能帮我优化我的 React 应用程序吗?它渲染缓慢并且有性能问题。 Assistant: 我将帮助优化你的 React 应用程序。首先,让我检查你的代码库以识别潜在的性能瓶颈。 *审查组件结构、渲染模式、状态管理和数据获取* Assistant: 在分析你的代码库后,我发现了几个性能问题。让我创建一个待办事项列表来跟踪我们的优化工作。 *创建包含以下项目的待办事项列表:1) 为 ProductList 中的昂贵计算实现记忆化,2) 为 Dashboard 中的长列表添加虚拟化,3) 优化 Gallery 组件中的图像加载,4) 修复 ShoppingCart 中的状态更新循环,5) 审查包大小并实现代码拆分* 让我们从为 ProductList 组件中的昂贵计算实现记忆化开始。 助手使用待办事项列表是因为: 1. 首先,助手检查代码库以识别特定的性能问题 2. 基于此分析,它确定了多个优化机会 3. 性能优化是一个需要多个步骤的非平凡任务 4. 待办事项列表有助于有条理地跟踪跨不同组件的改进 5. 这种系统化的方法确保解决所有性能瓶颈 ## 何时不使用待办事项列表的示例 User: 如何在 Python 中打印 'Hello World'? Assistant: 在 Python 中,你可以用这个简单的代码打印 "Hello World": python print("Hello World") 执行时,这将在控制台输出文本 "Hello World"。 助手没有使用待办事项列表,因为这是一个可以在一步内完成的单一、简单任务。对于这样一个简单直接的请求,没有必要跟踪多个任务或步骤。 User: git status 命令是做什么的? Assistant: git status 命令显示你的工作目录和暂存区的当前状态。它显示哪些更改已暂存、哪些尚未暂存,以及哪些文件未被 Git 跟踪。这是在提交之前了解存储库当前状态的基本命令。 助手没有使用待办事项列表,因为这是一个没有实际编码任务要完成的信息请求。用户只是要求解释,而不是要求助手执行多个步骤或任务。 User: 你能给 calculateTotal 函数添加一个注释来解释它的作用吗? Assistant: 当然,让我给 calculateTotal 函数添加一个注释来解释它的作用。 * 使用 Edit 工具给 calculateTotal 函数添加注释 * 助手没有使用待办事项列表,因为这是一个局限于代码中一个位置的单一、简单任务。添加注释不需要跟踪多个步骤或系统化组织。 User: 为我运行 npm install 并告诉我发生了什么。 Assistant: 我将为你运行 npm install 命令。 *执行:npm install* 命令成功完成。以下是输出: [npm install 命令的输出] 所有依赖项已根据你的 package.json 文件安装。 助手没有使用待办事项列表,因为这是一个具有即时结果的单一命令执行。没有多个步骤需要跟踪或组织,使得待办事项列表对于这个简单直接的任务是不必要的。 ## 任务状态和管理 1. **任务状态**:使用这些状态来跟踪进度: - pending:任务尚未开始 - in_progress:当前正在处理(一次限制为一个任务) - completed:任务成功完成 **重要**:任务描述必须有两种形式: - content:描述需要做什么的祈使形式(例如"运行测试"、"构建项目") - activeForm:执行期间显示的现在进行时形式(例如"正在运行测试"、"正在构建项目") 2. **任务管理**: - 在工作时实时更新任务状态 - 完成后立即标记任务为已完成(不要批量完成) - 任何时候都必须恰好有一个任务处于进行中状态(不能少,也不能多) - 在开始新任务之前完成当前任务 - 从列表中完全删除不再相关的任务 3. **任务完成要求**: - 只有在你完全完成任务时才将其标记为已完成 - 如果你遇到错误、阻碍或无法完成,请将任务保持为进行中 - 当被阻止时,创建一个新任务描述需要解决的问题 - 在以下情况下永远不要将任务标记为已完成: - 测试失败 - 实现不完整 - 你遇到了未解决的错误 - 你找不到必要的文件或依赖项 4. 
**任务分解**: - 创建具体、可操作的项目 - 将复杂任务分解为更小、可管理的步骤 - 使用清晰、描述性的任务名称 - 始终提供两种形式: - content:"修复身份验证 bug" - activeForm:"正在修复身份验证 bug" 如有疑问,请使用此工具。主动进行任务管理可以确保你成功完成所有要求。 ` ) ================================================ FILE: adk/prebuilt/deep/task_tool.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package deep import ( "context" "encoding/json" "fmt" "strings" "github.com/bytedance/sonic" "github.com/slongfield/pyfmt" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) func newTaskToolMiddleware( ctx context.Context, taskToolDescriptionGenerator func(ctx context.Context, subAgents []adk.Agent) (string, error), subAgents []adk.Agent, withoutGeneralSubAgent bool, // cm is the chat model. Tools are configured via model.WithTools call option. 
cm model.BaseChatModel, instruction string, toolsConfig adk.ToolsConfig, maxIteration int, middlewares []adk.AgentMiddleware, handlers []adk.ChatModelAgentMiddleware, ) (adk.ChatModelAgentMiddleware, error) { t, err := newTaskTool(ctx, taskToolDescriptionGenerator, subAgents, withoutGeneralSubAgent, cm, instruction, toolsConfig, maxIteration, middlewares, handlers) if err != nil { return nil, err } prompt := internal.SelectPrompt(internal.I18nPrompts{ English: taskPrompt, Chinese: taskPromptChinese, }) return buildAppendPromptTool(prompt, t), nil } func newTaskTool( ctx context.Context, taskToolDescriptionGenerator func(ctx context.Context, subAgents []adk.Agent) (string, error), subAgents []adk.Agent, withoutGeneralSubAgent bool, // Model is the chat model. Tools are configured via model.WithTools call option. Model model.BaseChatModel, Instruction string, ToolsConfig adk.ToolsConfig, MaxIteration int, middlewares []adk.AgentMiddleware, handlers []adk.ChatModelAgentMiddleware, ) (tool.InvokableTool, error) { t := &taskTool{ subAgents: map[string]tool.InvokableTool{}, subAgentSlice: subAgents, descGen: defaultTaskToolDescription, } if taskToolDescriptionGenerator != nil { t.descGen = taskToolDescriptionGenerator } if !withoutGeneralSubAgent { agentDesc := internal.SelectPrompt(internal.I18nPrompts{ English: generalAgentDescription, Chinese: generalAgentDescriptionChinese, }) generalAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: generalAgentName, Description: agentDesc, Instruction: Instruction, Model: Model, ToolsConfig: ToolsConfig, MaxIterations: MaxIteration, Middlewares: middlewares, Handlers: handlers, GenModelInput: genModelInput, }) if err != nil { return nil, err } it, err := assertAgentTool(adk.NewAgentTool(ctx, generalAgent)) if err != nil { return nil, err } t.subAgents[generalAgent.Name(ctx)] = it t.subAgentSlice = append(t.subAgentSlice, generalAgent) } for _, a := range subAgents { name := a.Name(ctx) it, err := 
assertAgentTool(adk.NewAgentTool(ctx, a)) if err != nil { return nil, err } t.subAgents[name] = it } return t, nil } type taskTool struct { subAgents map[string]tool.InvokableTool subAgentSlice []adk.Agent descGen func(ctx context.Context, subAgents []adk.Agent) (string, error) } func (t *taskTool) Info(ctx context.Context) (*schema.ToolInfo, error) { desc, err := t.descGen(ctx, t.subAgentSlice) if err != nil { return nil, err } return &schema.ToolInfo{ Name: taskToolName, Desc: desc, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "subagent_type": { Type: schema.String, }, "description": { Type: schema.String, }, }), }, nil } type taskToolArgument struct { SubagentType string `json:"subagent_type"` Description string `json:"description"` } func (t *taskTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { input := &taskToolArgument{} err := json.Unmarshal([]byte(argumentsInJSON), input) if err != nil { return "", fmt.Errorf("failed to unmarshal task tool input json: %w", err) } a, ok := t.subAgents[input.SubagentType] if !ok { return "", fmt.Errorf("subagent type %s not found", input.SubagentType) } params, err := sonic.MarshalString(map[string]string{ "request": input.Description, }) if err != nil { return "", err } return a.InvokableRun(ctx, params, opts...) 
} func defaultTaskToolDescription(ctx context.Context, subAgents []adk.Agent) (string, error) { subAgentsDescBuilder := strings.Builder{} for _, a := range subAgents { name := a.Name(ctx) desc := a.Description(ctx) subAgentsDescBuilder.WriteString(fmt.Sprintf("- %s: %s\n", name, desc)) } toolDesc := internal.SelectPrompt(internal.I18nPrompts{ English: taskToolDescription, Chinese: taskToolDescriptionChinese, }) return pyfmt.Fmt(toolDesc, map[string]any{ "other_agents": subAgentsDescBuilder.String(), }) } ================================================ FILE: adk/prebuilt/deep/task_tool_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package deep import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/schema" ) func TestTaskTool(t *testing.T) { a1 := &myAgent{name: "1", desc: "desc of my agent 1"} a2 := &myAgent{name: "2", desc: "desc of my agent 2"} ctx := context.Background() tt, err := newTaskTool( ctx, nil, []adk.Agent{a1, a2}, true, nil, "", adk.ToolsConfig{}, 10, nil, nil, ) assert.NoError(t, err) info, err := tt.Info(ctx) assert.NoError(t, err) assert.Contains(t, info.Desc, "desc of my agent 1") result, err := tt.InvokableRun(ctx, `{"subagent_type":"1"}`) assert.NoError(t, err) assert.Equal(t, "desc of my agent 1", result) result, err = tt.InvokableRun(ctx, `{"subagent_type":"2"}`) assert.NoError(t, err) assert.Equal(t, "desc of my agent 2", result) } type myAgent struct { name string desc string } func (m *myAgent) Name(ctx context.Context) string { return m.name } func (m *myAgent) Description(ctx context.Context) string { return m.desc } func (m *myAgent) Run(ctx context.Context, input *adk.AgentInput, options ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iter, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]() gen.Send(adk.EventFromMessage(schema.UserMessage(m.desc), nil, schema.User, "")) gen.Close() return iter } ================================================ FILE: adk/prebuilt/deep/testdata/_gen/generate_test.go ================================================ //go:build gencheckpoints /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package _gen import ( "context" "errors" "os" "path/filepath" "testing" "github.com/stretchr/testify/require" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/prebuilt/deep" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) type checkpointStore struct { data map[string][]byte } func (s *checkpointStore) Set(_ context.Context, key string, value []byte) error { if s.data == nil { s.data = map[string][]byte{} } s.data[key] = append([]byte(nil), value...) return nil } func (s *checkpointStore) Get(_ context.Context, key string) ([]byte, bool, error) { v, ok := s.data[key] if !ok { return nil, false, nil } return append([]byte(nil), v...), true, nil } type interruptTool struct { name string } func (t *interruptTool) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: t.name, Desc: "interrupt tool", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "action": {Type: schema.String}, }), }, nil } func (t *interruptTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { wasInterrupted, _, _ := tool.GetInterruptState[string](ctx) if !wasInterrupted { return "", tool.StatefulInterrupt(ctx, "needs approval", argumentsInJSON) } return "resumed", nil } type scriptedModel struct { next func() (*schema.Message, error) } func (m *scriptedModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { return m.next() } func (m *scriptedModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { return nil, errors.New("stream not supported") } func TestGenerateV084CheckpointData(t *testing.T) { if os.Getenv("EINO_UPDATE_CHECKPOINT_FIXTURES") != "1" { 
t.Skip("set EINO_UPDATE_CHECKPOINT_FIXTURES=1 to generate checkpoint fixtures") } ctx := context.Background() interruptToolName := "interrupt_in_subagent_tool" subTool := &interruptTool{name: interruptToolName} deepCalls := 0 deepModel := &scriptedModel{ next: func() (*schema.Message, error) { deepCalls++ if deepCalls == 1 { c := schema.ToolCall{ID: "id-1", Type: "function"} c.Function.Name = "task" c.Function.Arguments = `{"subagent_type":"sub_chatmodel_agent","description":"from_parent"}` return schema.AssistantMessage("", []schema.ToolCall{c}), nil } return schema.AssistantMessage("deep done", nil), nil }, } subCalls := 0 subModel := &scriptedModel{ next: func() (*schema.Message, error) { subCalls++ if subCalls == 1 { c := schema.ToolCall{ID: "id-2", Type: "function"} c.Function.Name = interruptToolName c.Function.Arguments = `{"action":"interrupt"}` return schema.AssistantMessage("", []schema.ToolCall{c}), nil } return schema.AssistantMessage("sub done", nil), nil }, } subAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "sub_chatmodel_agent", Description: "sub agent", Model: subModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{subTool}, }, }, MaxIterations: 4, }) require.NoError(t, err) deepAgent, err := deep.New(ctx, &deep.Config{ Name: "deep", Description: "deep agent", ChatModel: deepModel, SubAgents: []adk.Agent{subAgent}, MaxIteration: 4, WithoutWriteTodos: true, WithoutGeneralSubAgent: true, }) require.NoError(t, err) store := &checkpointStore{data: map[string][]byte{}} runner := adk.NewRunner(ctx, adk.RunnerConfig{ Agent: deepAgent, CheckPointStore: store, }) checkpointID := "checkpoint_gen_v0_8_4" it := runner.Query(ctx, "input", adk.WithCheckPointID(checkpointID)) for { ev, ok := it.Next() if !ok { break } require.NoError(t, ev.Err) } data, ok := store.data[checkpointID] require.True(t, ok) require.NotEmpty(t, data) outPath := filepath.Clean(filepath.Join("..", 
"checkpoint_data_v0.8.4.bin")) require.NoError(t, os.WriteFile(outPath, data, 0o644)) } ================================================ FILE: adk/prebuilt/deep/types.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package deep import ( "context" "fmt" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/tool" ) const ( generalAgentName = "general-purpose" taskToolName = "task" ) const ( SessionKeyTodos = "deep_agent_session_key_todos" ) func assertAgentTool(t tool.BaseTool) (tool.InvokableTool, error) { it, ok := t.(tool.InvokableTool) if !ok { return nil, fmt.Errorf("failed to assert agent tool type: %T", t) } return it, nil } func buildAppendPromptTool(prompt string, t tool.BaseTool) adk.ChatModelAgentMiddleware { return &appendPromptTool{ BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, t: t, prompt: prompt, } } type appendPromptTool struct { *adk.BaseChatModelAgentMiddleware t tool.BaseTool prompt string } func (w *appendPromptTool) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { nRunCtx := *runCtx nRunCtx.Instruction += w.prompt if w.t != nil { nRunCtx.Tools = append(nRunCtx.Tools, w.t) } return ctx, &nRunCtx, nil } ================================================ FILE: adk/prebuilt/integration_test.go ================================================ /* * Copyright 2025 
CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package prebuilt

import (
	"context"
	"fmt"
	"strings"
	"testing"

	"github.com/bytedance/sonic"
	"github.com/stretchr/testify/assert"
	"go.uber.org/mock/gomock"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/adk/prebuilt/planexecute"
	"github.com/cloudwego/eino/adk/prebuilt/supervisor"
	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/compose"
	mockModel "github.com/cloudwego/eino/internal/mock/components/model"
	"github.com/cloudwego/eino/schema"
)

// approvalInfo is the interrupt payload raised by approvableTool. It describes
// the pending tool call that is waiting for approval.
type approvalInfo struct {
	ToolName        string
	ArgumentsInJSON string
	ToolCallID      string
}

// String renders a human-readable description of the pending approval.
func (ai *approvalInfo) String() string {
	return fmt.Sprintf("tool '%s' interrupted with arguments '%s', waiting for approval", ai.ToolName, ai.ArgumentsInJSON)
}

// approvalResult is the resume payload that approvableTool expects.
type approvalResult struct {
	Approved         bool
	DisapproveReason *string
}

// init registers the interrupt and resume payload types with the schema
// registry (presumably so they can round-trip through checkpoints).
func init() {
	schema.Register[*approvalInfo]()
	schema.Register[*approvalResult]()
}

// approvableTool is a test tool that interrupts on its first invocation and
// only produces a result after being resumed with an *approvalResult.
type approvableTool struct {
	name string
	// t is kept for debugging; it is not read by the methods below.
	t *testing.T
}

// Info describes the tool as taking a single string "action" parameter.
func (m *approvableTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{
		Name: m.name,
		Desc: "A tool that requires approval before execution",
		ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
			"action": {Type: schema.String, Desc: "The action to perform"},
		}),
	}, nil
}

// InvokableRun implements a two-phase approval flow:
//   - not previously interrupted: raise a stateful interrupt that stores the
//     arguments as interrupt state;
//   - interrupted but not the resume target: re-interrupt with the stored
//     arguments;
//   - resume target: return a success or disapproval message based on the
//     *approvalResult resume data; resuming without data is an error.
func (m *approvableTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) {
	wasInterrupted, _, storedArguments := tool.GetInterruptState[string](ctx)
	if !wasInterrupted {
		return "", tool.StatefulInterrupt(ctx, &approvalInfo{
			ToolName:        m.name,
			ArgumentsInJSON: argumentsInJSON,
			ToolCallID:      compose.GetToolCallID(ctx),
		}, argumentsInJSON)
	}

	isResumeTarget, hasData, data := tool.GetResumeContext[*approvalResult](ctx)
	if !isResumeTarget {
		// This tool is not the one being resumed: stay interrupted.
		return "", tool.StatefulInterrupt(ctx, &approvalInfo{
			ToolName:        m.name,
			ArgumentsInJSON: storedArguments,
			ToolCallID:      compose.GetToolCallID(ctx),
		}, storedArguments)
	}
	if !hasData {
		return "", fmt.Errorf("tool '%s' resumed with no data", m.name)
	}

	if data.Approved {
		return fmt.Sprintf("Tool '%s' executed successfully with args: %s", m.name, storedArguments), nil
	}
	if data.DisapproveReason != nil {
		return fmt.Sprintf("Tool '%s' disapproved, reason: %s", m.name, *data.DisapproveReason), nil
	}
	return fmt.Sprintf("Tool '%s' disapproved", m.name), nil
}

// integrationCheckpointStore is an unsynchronized in-memory checkpoint store
// used by the test runner.
type integrationCheckpointStore struct {
	data map[string][]byte
}

// newIntegrationCheckpointStore returns an empty store.
func newIntegrationCheckpointStore() *integrationCheckpointStore {
	return &integrationCheckpointStore{data: make(map[string][]byte)}
}

// Set stores value under key, replacing any previous value.
func (s *integrationCheckpointStore) Set(_ context.Context, key string, value []byte) error {
	s.data[key] = value
	return nil
}

// Get returns the value stored under key and whether it was present.
func (s *integrationCheckpointStore) Get(_ context.Context, key string) ([]byte, bool, error) {
	v, ok := s.data[key]
	return v, ok, nil
}

// defaultPlan is a local Plan implementation used to build the planner's
// mocked tool-call arguments.
type defaultPlan struct {
	Steps []string `json:"steps"`
}

// FirstStep returns the first step, or "" when the plan is empty.
func (p *defaultPlan) FirstStep() string {
	if len(p.Steps) == 0 {
		return ""
	}
	return p.Steps[0]
}

// MarshalJSON encodes through an alias type so the encoder does not recurse
// back into this method.
func (p *defaultPlan) MarshalJSON() ([]byte, error) {
	type planTyp defaultPlan
	return sonic.Marshal((*planTyp)(p))
}

// UnmarshalJSON decodes through an alias type so the decoder does not recurse
// back into this method.
func (p *defaultPlan) UnmarshalJSON(bytes []byte) error {
	type planTyp defaultPlan
	return sonic.Unmarshal(bytes, (*planTyp)(p))
}

// namedAgent wraps a ResumableAgent, overriding its Name and Description while
// delegating everything else to the embedded agent.
type namedAgent struct {
	adk.ResumableAgent
	name        string
	description string
}

// Name returns the overridden agent name.
func (n *namedAgent) Name(_ context.Context) string {
	return n.name
}

// Description returns the overridden agent description.
func (n *namedAgent) Description(_ context.Context) string {
	return n.description
}

// formatRunPath renders a run path as "[step1 -> step2 -> ...]", or "[]" when
// it is empty.
func formatRunPath(runPath []adk.RunStep) string {
	if len(runPath) == 0 {
		return "[]"
	}
	var parts []string
	for _, step := range runPath {
		parts = append(parts, step.String())
	}
	return "[" + strings.Join(parts, " -> ") + "]"
}

// formatAgentEventIntegration renders the AgentEvent fields this test inspects
// (agent name, run path, message output, interrupt/break-loop/transfer actions,
// error) as a single-line string for t.Logf.
func formatAgentEventIntegration(event *adk.AgentEvent) string {
	var sb strings.Builder
	sb.WriteString(fmt.Sprintf("{AgentName: %q, RunPath: %s", event.AgentName, formatRunPath(event.RunPath)))

	if event.Output != nil {
		if event.Output.MessageOutput != nil && event.Output.MessageOutput.Message != nil {
			msg := event.Output.MessageOutput.Message
			sb.WriteString(fmt.Sprintf(", Output.Message: {Role: %q, Content: %q}", msg.Role, msg.Content))
		}
	}

	if event.Action != nil {
		if event.Action.Interrupted != nil {
			sb.WriteString(fmt.Sprintf(", Action.Interrupted: {%d contexts}", len(event.Action.Interrupted.InterruptContexts)))
		}
		if event.Action.BreakLoop != nil {
			sb.WriteString(fmt.Sprintf(", Action.BreakLoop: {From: %q, Done: %v}", event.Action.BreakLoop.From, event.Action.BreakLoop.Done))
		}
		if event.Action.TransferToAgent != nil {
			sb.WriteString(fmt.Sprintf(", Action.TransferToAgent: {Dest: %q}", event.Action.TransferToAgent.DestAgentName))
		}
	}

	if event.Err != nil {
		sb.WriteString(fmt.Sprintf(", Err: %v", event.Err))
	}

	sb.WriteString("}")
	return sb.String()
}

// TestSupervisorWithPlanExecuteInterruptResume wires a supervisor agent that
// transfers to a plan-execute agent whose executor calls approvableTool.
// The first run must stop with an interrupt raised inside the tool. Resuming
// the root-cause interrupt with an approval must then yield a successful
// tool response and a BreakLoop{Done: true} action, with no event errors.
func TestSupervisorWithPlanExecuteInterruptResume(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockSupervisorModel := mockModel.NewMockToolCallingChatModel(ctrl)
	mockPlannerModel := mockModel.NewMockToolCallingChatModel(ctrl)
	mockExecutorModel := mockModel.NewMockToolCallingChatModel(ctrl)
	mockReplannerModel := mockModel.NewMockToolCallingChatModel(ctrl)

	budgetTool := &approvableTool{name: "allocate_budget", t: t}

	plan := &defaultPlan{Steps: []string{"Allocate budget for the project", "Complete task"}}
	userInput := []adk.Message{schema.UserMessage("Set up a new project with budget allocation")}

	// Planner: its tool-bound model streams one "plan" tool call carrying the
	// JSON-encoded plan, exactly once.
	plannerModelWithTools := mockModel.NewMockToolCallingChatModel(ctrl)
	mockPlannerModel.EXPECT().WithTools(gomock.Any()).Return(plannerModelWithTools, nil).AnyTimes()

	planJSON, _ := sonic.MarshalString(plan)
	plannerResponse := schema.AssistantMessage("", []schema.ToolCall{
		{
			ID:   "plan_call_1",
			Type: "function",
			Function: schema.FunctionCall{
				Name:      "plan",
				Arguments: planJSON,
			},
		},
	})
	plannerModelWithTools.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, input []*schema.Message, opts ...interface{}) (*schema.StreamReader[*schema.Message], error) {
			sr, sw := schema.Pipe[*schema.Message](1)
			go func() {
				defer sw.Close()
				sw.Send(plannerResponse, nil)
			}()
			return sr, nil
		},
	).Times(1)

	// Executor: the first Generate calls allocate_budget (which interrupts);
	// every later Generate reports completion.
	mockExecutorModel.EXPECT().WithTools(gomock.Any()).Return(mockExecutorModel, nil).AnyTimes()

	toolCallMsg := schema.AssistantMessage("", []schema.ToolCall{
		{
			ID:   "call_budget_1",
			Type: "function",
			Function: schema.FunctionCall{
				Name:      "allocate_budget",
				Arguments: `{"action": "allocate $50000 for project"}`,
			},
		},
	})
	mockExecutorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(toolCallMsg, nil).Times(1)

	completionMsg := schema.AssistantMessage("Budget allocated successfully", nil)
	mockExecutorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(completionMsg, nil).AnyTimes()

	// Replanner: always streams a "respond" tool call (presumably what ends the
	// plan-execute loop with BreakLoop.Done).
	replannerModelWithTools := mockModel.NewMockToolCallingChatModel(ctrl)
	mockReplannerModel.EXPECT().WithTools(gomock.Any()).Return(replannerModelWithTools, nil).AnyTimes()

	respondResponse := schema.AssistantMessage("", []schema.ToolCall{
		{
			ID:   "respond_call_1",
			Type: "function",
			Function: schema.FunctionCall{
				Name:      "respond",
				Arguments: `{"response":"Project setup completed with budget allocation"}`,
			},
		},
	})
	replannerModelWithTools.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, input []*schema.Message, opts ...interface{}) (*schema.StreamReader[*schema.Message], error) {
			sr, sw := schema.Pipe[*schema.Message](1)
			go func() {
				defer sw.Close()
				sw.Send(respondResponse, nil)
			}()
			return sr, nil
		},
	).AnyTimes()

	// NOTE(review): construction errors are checked with assert, so a failure
	// continues with nil agents and panics later; require would stop here.
	plannerAgent, err := planexecute.NewPlanner(ctx, &planexecute.PlannerConfig{
		ToolCallingChatModel: mockPlannerModel,
	})
	assert.NoError(t, err)

	executorAgent, err := planexecute.NewExecutor(ctx, &planexecute.ExecutorConfig{
		Model: mockExecutorModel,
		ToolsConfig: adk.ToolsConfig{
			ToolsNodeConfig: compose.ToolsNodeConfig{
				Tools: []tool.BaseTool{budgetTool},
			},
		},
	})
	assert.NoError(t, err)

	replannerAgent, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{
		ChatModel: mockReplannerModel,
	})
	assert.NoError(t, err)

	planExecuteAgent, err := planexecute.New(ctx, &planexecute.Config{
		Planner:       plannerAgent,
		Executor:      executorAgent,
		Replanner:     replannerAgent,
		MaxIterations: 10,
	})
	assert.NoError(t, err)

	// Rename the plan-execute agent so the supervisor can transfer to it by name.
	projectAgent := &namedAgent{
		ResumableAgent: planExecuteAgent,
		name:           "project_execution_agent",
		description:    "the agent responsible for complex project execution tasks",
	}

	// The wrapper must still satisfy adk.ResumableAgent for resume to work.
	var pa adk.Agent
	pa = projectAgent
	_, ok := pa.(adk.ResumableAgent)
	assert.True(t, ok)

	// Supervisor: first transfers to project_execution_agent, then returns a
	// final answer on every later call.
	mockSupervisorModel.EXPECT().WithTools(gomock.Any()).Return(mockSupervisorModel, nil).AnyTimes()

	transferToProjectMsg := schema.AssistantMessage("", []schema.ToolCall{
		{
			ID:   "transfer_call_1",
			Type: "function",
			Function: schema.FunctionCall{
				Name:      "transfer_to_agent",
				Arguments: `{"agent_name":"project_execution_agent"}`,
			},
		},
	})
	mockSupervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(transferToProjectMsg, nil).Times(1)

	finalSupervisorMsg := schema.AssistantMessage("Project setup completed successfully with budget allocation approved.", nil)
	mockSupervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(finalSupervisorMsg, nil).AnyTimes()

	supervisorChatAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
		Name:        "project_manager",
		Description: "the supervisor agent responsible for coordinating project management tasks",
		Instruction: "You are a project manager supervisor. Delegate complex project tasks to project_execution_agent.",
		Model:       mockSupervisorModel,
		Exit:        &adk.ExitTool{},
	})
	assert.NoError(t, err)

	supervisorAgent, err := supervisor.New(ctx, &supervisor.Config{
		Supervisor: supervisorChatAgent,
		SubAgents:  []adk.Agent{projectAgent},
	})
	assert.NoError(t, err)

	store := newIntegrationCheckpointStore()
	runner := adk.NewRunner(ctx, adk.RunnerConfig{
		Agent:           supervisorAgent,
		CheckPointStore: store,
	})

	t.Log("========================================")
	t.Log("Starting Supervisor + PlanExecute Integration Test")
	t.Log("========================================")

	// Phase 1: run until the first interrupt event.
	checkpointID := "test-supervisor-plan_execute-1"
	iter := runner.Run(ctx, userInput, adk.WithCheckPointID(checkpointID))

	var interruptEvent *adk.AgentEvent
	eventCount := 0
	for {
		event, ok := iter.Next()
		if !ok {
			break
		}
		eventCount++

		t.Logf("Event %d: %s", eventCount, formatAgentEventIntegration(event))

		if event.Err != nil {
			t.Logf("Event has error: %v", event.Err)
		}

		if event.Action != nil && event.Action.Interrupted != nil {
			interruptEvent = event
			t.Log("========================================")
			t.Log("INTERRUPT DETECTED - Deep interrupt from tool within executor")
			t.Log("========================================")
			break
		}
	}

	if interruptEvent == nil {
		t.Fatal("Expected an interrupt event from the approvable tool, but none was received")
	}

	assert.NotNil(t, interruptEvent.Action.Interrupted, "Should have interrupt info")
	assert.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts, "Should have interrupt contexts")

	t.Logf("Interrupt event received with %d contexts", len(interruptEvent.Action.Interrupted.InterruptContexts))
	// NOTE(review): the loop variable ctx shadows the outer context.Context
	// inside this loop only; harmless, but easy to misread.
	for i, ctx := range interruptEvent.Action.Interrupted.InterruptContexts {
		t.Logf("Interrupt context %d: ID=%s, Info=%v, IsRootCause=%v", i, ctx.ID, ctx.Info, ctx.IsRootCause)
	}

	// Phase 2: resume the root-cause interrupt (the tool) with an approval.
	var toolInterruptID string
	for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts {
		if intCtx.IsRootCause {
			toolInterruptID = intCtx.ID
			break
		}
	}
	assert.NotEmpty(t, toolInterruptID, "Should have a root cause interrupt ID")

	t.Log("========================================")
	t.Logf("Resuming with approval for interrupt ID: %s", toolInterruptID)
	t.Log("========================================")

	resumeIter, err := runner.ResumeWithParams(ctx, checkpointID, &adk.ResumeParams{
		Targets: map[string]any{
			toolInterruptID: &approvalResult{Approved: true},
		},
	})
	assert.NoError(t, err, "Resume should not error")
	assert.NotNil(t, resumeIter, "Resume iterator should not be nil")

	var resumeEvents []*adk.AgentEvent
	for {
		event, ok := resumeIter.Next()
		if !ok {
			break
		}
		resumeEvents = append(resumeEvents, event)
	}

	assert.NotEmpty(t, resumeEvents, "Should have resume events after approval")

	for _, event := range resumeEvents {
		assert.NoError(t, event.Err, "Resume event should not have error")
	}

	// The resumed run must contain the approved tool output and a completed
	// break-loop action.
	var hasToolResponse, hasBreakLoop bool
	for _, event := range resumeEvents {
		if event.Output != nil && event.Output.MessageOutput != nil {
			msg := event.Output.MessageOutput.Message
			if msg != nil && msg.Role == "tool" && strings.Contains(msg.Content, "executed successfully") {
				hasToolResponse = true
			}
		}
		if event.Action != nil && event.Action.BreakLoop != nil && event.Action.BreakLoop.Done {
			hasBreakLoop = true
		}
	}

	assert.True(t, hasToolResponse, "Should have tool response indicating successful execution")
	assert.True(t, hasBreakLoop, "Should have break loop action indicating task completion")
}

================================================
FILE: adk/prebuilt/planexecute/plan_execute.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Package planexecute implements a plan–execute–replan style agent.
package planexecute

import (
	"context"
	"encoding/json"
	"fmt"
	"runtime/debug"
	"strings"

	"github.com/bytedance/sonic"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/internal/safe"
	"github.com/cloudwego/eino/schema"
)

// init registers the types this package stores in the session under stable
// names (presumably so they survive checkpoint serialization — confirm).
func init() {
	schema.RegisterName[*defaultPlan]("_eino_adk_plan_execute_default_plan")
	schema.RegisterName[ExecutedStep]("_eino_adk_plan_execute_executed_step")
	schema.RegisterName[[]ExecutedStep]("_eino_adk_plan_execute_executed_steps")
}

// Plan represents an execution plan with a sequence of actionable steps.
// It supports JSON serialization and deserialization while providing access to the first step.
type Plan interface {
	// FirstStep returns the first step to be executed in the plan.
	FirstStep() string
	// Marshaler serializes the Plan into JSON.
	// The resulting JSON can be used in prompt templates.
	json.Marshaler
	// Unmarshaler deserializes JSON content into the Plan.
	// This processes output from structured chat models or tool calls into the Plan structure.
	json.Unmarshaler
}

// NewPlan is a function type that creates a new Plan instance.
type NewPlan func(ctx context.Context) Plan

// defaultPlan is the default implementation of the Plan interface.
//
// JSON Schema:
//
//	{
//	  "type": "object",
//	  "properties": {
//	    "steps": {
//	      "type": "array",
//	      "items": {
//	        "type": "string"
//	      },
//	      "description": "Ordered list of actions to be taken. Each step should be clear, actionable, and arranged in a logical sequence."
//	    }
//	  },
//	  "required": ["steps"]
//	}
type defaultPlan struct {
	// Steps contains the ordered list of actions to be taken.
	// Each step should be clear, actionable, and arranged in a logical sequence.
	Steps []string `json:"steps"`
}

// FirstStep returns the first step in the plan or an empty string if no steps exist.
func (p *defaultPlan) FirstStep() string {
	if len(p.Steps) == 0 {
		return ""
	}
	return p.Steps[0]
}

// MarshalJSON encodes the plan through an alias type so the encoder does not
// recurse back into this method.
func (p *defaultPlan) MarshalJSON() ([]byte, error) {
	type planTyp defaultPlan
	return sonic.Marshal((*planTyp)(p))
}

// UnmarshalJSON decodes the plan through an alias type so the decoder does not
// recurse back into this method.
func (p *defaultPlan) UnmarshalJSON(bytes []byte) error {
	type planTyp defaultPlan
	return sonic.Unmarshal(bytes, (*planTyp)(p))
}

// Response represents the final response to the user.
// This struct is used for JSON serialization/deserialization of the final response
// generated by the model.
type Response struct {
	// Response is the complete response to provide to the user.
	// This field is required.
	Response string `json:"response"`
}

var (
	// PlanToolInfo defines the schema for the Plan tool that can be used with ToolCallingChatModel.
	// This schema instructs the model to generate a structured plan with ordered steps.
	PlanToolInfo = schema.ToolInfo{
		Name: "plan",
		Desc: "Plan with a list of steps to execute in order. Each step should be clear, actionable, and arranged in a logical sequence. The output will be used to guide the execution process.",
		ParamsOneOf: schema.NewParamsOneOfByParams(
			map[string]*schema.ParameterInfo{
				"steps": {
					Type:     schema.Array,
					ElemInfo: &schema.ParameterInfo{Type: schema.String},
					Desc:     "different steps to follow, should be in sorted order",
					Required: true,
				},
			},
		),
	}

	// RespondToolInfo defines the schema for the response tool that can be used with ToolCallingChatModel.
	// This schema instructs the model to generate a direct response to the user.
	RespondToolInfo = schema.ToolInfo{
		Name: "respond",
		Desc: "Generate a direct response to the user. Use this tool when you have all the information needed to provide a final answer.",
		ParamsOneOf: schema.NewParamsOneOfByParams(
			map[string]*schema.ParameterInfo{
				"response": {
					Type:     schema.String,
					Desc:     "The complete response to provide to the user",
					Required: true,
				},
			},
		),
	}

	// PlannerPrompt is the prompt template for the planner.
	// It provides context and guidance to the planner on how to generate the Plan.
	// The user input messages are bound to the "input" placeholder.
	PlannerPrompt = prompt.FromMessages(schema.FString,
		schema.SystemMessage(`You are an expert planning agent. Given an objective, create a comprehensive step-by-step plan to achieve the objective. ## YOUR TASK Analyze the objective and generate a strategic plan that breaks down the goal into manageable, executable steps.
## PLANNING REQUIREMENTS Each step in your plan must be: - **Specific and actionable**: Clear instructions that can be executed without ambiguity - **Self-contained**: Include all necessary context, parameters, and requirements - **Independently executable**: Can be performed by an agent without dependencies on other steps - **Logically sequenced**: Arranged in optimal order for efficient execution - **Objective-focused**: Directly contribute to achieving the main goal ## PLANNING GUIDELINES - Eliminate redundant or unnecessary steps - Include relevant constraints, parameters, and success criteria for each step - Ensure the final step produces a complete answer or deliverable - Anticipate potential challenges and include mitigation strategies - Structure steps to build upon each other logically - Provide sufficient detail for successful execution ## QUALITY CRITERIA - Plan completeness: Does it address all aspects of the objective? - Step clarity: Can each step be understood and executed independently? - Logical flow: Do steps follow a sensible progression? - Efficiency: Is this the most direct path to the objective? - Adaptability: Can the plan handle unexpected results or changes?`),
		schema.MessagesPlaceholder("input", false),
	)

	// ExecutorPrompt is the prompt template for the executor.
	// It provides context and guidance to the executor on how to execute the Task.
	// Placeholders: {input}, {plan}, {executed_steps}, {step}.
	ExecutorPrompt = prompt.FromMessages(schema.FString,
		schema.SystemMessage(`You are a diligent and meticulous executor agent. Follow the given plan and execute your tasks carefully and thoroughly.`),
		schema.UserMessage(`## OBJECTIVE {input} ## Given the following plan: {plan} ## COMPLETED STEPS & RESULTS {executed_steps} ## Your task is to execute the first step, which is: {step}`))

	// ReplannerPrompt is the prompt template for the replanner.
	// It provides context and guidance to the replanner on how to regenerate the Plan.
	// Placeholders: {respond_tool}, {plan_tool}, {input}, {plan}, {executed_steps}.
	ReplannerPrompt = prompt.FromMessages(schema.FString,
		schema.SystemMessage(
			`You are going to review the progress toward an objective. Analyze the current state and determine the optimal next action. ## YOUR TASK Based on the progress above, you MUST choose exactly ONE action: ### Option 1: COMPLETE (if objective is fully achieved) Call '{respond_tool}' with: - A comprehensive final answer - Clear conclusion summarizing how the objective was met - Key insights from the execution process ### Option 2: CONTINUE (if more work is needed) Call '{plan_tool}' with a revised plan that: - Contains ONLY remaining steps (exclude completed ones) - Incorporates lessons learned from executed steps - Addresses any gaps or issues discovered - Maintains logical step sequence ## PLANNING REQUIREMENTS Each step in your plan must be: - **Specific and actionable**: Clear instructions that can be executed without ambiguity - **Self-contained**: Include all necessary context, parameters, and requirements - **Independently executable**: Can be performed by an agent without dependencies on other steps - **Logically sequenced**: Arranged in optimal order for efficient execution - **Objective-focused**: Directly contribute to achieving the main goal ## PLANNING GUIDELINES - Eliminate redundant or unnecessary steps - Adapt strategy based on new information - Include relevant constraints, parameters, and success criteria for each step ## DECISION CRITERIA - Has the original objective been completely satisfied? - Are there any remaining requirements or sub-goals? - Do the results suggest a need for strategy adjustment? - What specific actions are still required?`),
		schema.UserMessage(`## OBJECTIVE {input} ## ORIGINAL PLAN {plan} ## COMPLETED STEPS & RESULTS {executed_steps}`),
	)
)

const (
	// UserInputSessionKey is the session key for the user input.
	UserInputSessionKey = "UserInput"

	// PlanSessionKey is the session key for the plan.
	PlanSessionKey = "Plan"

	// ExecutedStepSessionKey is the session key for the execute result.
	ExecutedStepSessionKey = "ExecutedStep"

	// ExecutedStepsSessionKey is the session key for the execute results.
	ExecutedStepsSessionKey = "ExecutedSteps"
)

// PlannerConfig provides configuration options for creating a planner agent.
// There are two ways to configure the planner to generate structured Plan output:
//  1. Use ChatModelWithFormattedOutput: A model pre-configured to output in the Plan format
//  2. Use ToolCallingChatModel + ToolInfo: A model that uses tool calling to generate
//     the Plan structure
type PlannerConfig struct {
	// ChatModelWithFormattedOutput is a model pre-configured to output in the Plan format.
	// Create this by configuring a model to output structured data directly.
	// See example: https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
	ChatModelWithFormattedOutput model.BaseChatModel

	// ToolCallingChatModel is a model that supports tool calling capabilities.
	// When provided with ToolInfo, it will use tool calling to generate the Plan structure.
	ToolCallingChatModel model.ToolCallingChatModel
	// ToolInfo defines the schema for the Plan structure when using tool calling.
	// Optional. If not provided, PlanToolInfo will be used as the default.
	ToolInfo *schema.ToolInfo

	// GenInputFn is a function that generates the input messages for the planner.
	// Optional. If not provided, defaultGenPlannerInputFn will be used.
	GenInputFn GenPlannerModelInputFn

	// NewPlan creates a new Plan instance for JSON.
	// The returned Plan will be used to unmarshal the model-generated JSON output.
	// Optional. If not provided, defaultNewPlan will be used.
	NewPlan NewPlan
}

// GenPlannerModelInputFn is a function type that generates input messages for the planner.
type GenPlannerModelInputFn func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) func defaultNewPlan(ctx context.Context) Plan { return &defaultPlan{} } func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) { msgs, err := PlannerPrompt.Format(ctx, map[string]any{ "input": userInput, }) if err != nil { return nil, err } return msgs, nil } type planner struct { toolCall bool chatModel model.BaseChatModel genInputFn GenPlannerModelInputFn newPlan NewPlan } func (p *planner) Name(_ context.Context) string { return "planner" } func (p *planner) Description(_ context.Context) string { return "a planner agent" } func argToContent(msg adk.Message) (adk.Message, error) { if len(msg.ToolCalls) == 0 { return nil, schema.ErrNoValue } return schema.AssistantMessage(msg.ToolCalls[0].Function.Arguments, nil), nil } func (p *planner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() adk.AddSessionValue(ctx, UserInputSessionKey, input.Messages) go func() { defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) generator.Send(&adk.AgentEvent{Err: e}) } generator.Close() }() c := compose.NewChain[*adk.AgentInput, Plan](). AppendLambda( compose.InvokableLambda(func(ctx context.Context, input *adk.AgentInput) (output []adk.Message, err error) { return p.genInputFn(ctx, input.Messages) }), ). AppendChatModel(p.chatModel). 
AppendLambda( compose.CollectableLambda(func(ctx context.Context, sr *schema.StreamReader[adk.Message]) (adk.Message, error) { if input.EnableStreaming { ss := sr.Copy(2) var sOutput *schema.StreamReader[*schema.Message] if p.toolCall { sOutput = schema.StreamReaderWithConvert(ss[0], argToContent) } else { sOutput = ss[0] } generator.Send(adk.EventFromMessage(nil, sOutput, schema.Assistant, "")) return schema.ConcatMessageStream(ss[1]) } msg, err := schema.ConcatMessageStream(sr) if err != nil { return nil, err } var output adk.Message if p.toolCall { if len(msg.ToolCalls) == 0 { return nil, fmt.Errorf("no tool call") } output = schema.AssistantMessage(msg.ToolCalls[0].Function.Arguments, nil) } else { output = msg } generator.Send(adk.EventFromMessage(output, nil, schema.Assistant, "")) return msg, nil }), ). AppendLambda( compose.InvokableLambda(func(ctx context.Context, msg adk.Message) (plan Plan, err error) { var planJSON string if p.toolCall { if len(msg.ToolCalls) == 0 { return nil, fmt.Errorf("no tool call") } planJSON = msg.ToolCalls[0].Function.Arguments } else { planJSON = msg.Content } plan = p.newPlan(ctx) err = plan.UnmarshalJSON([]byte(planJSON)) if err != nil { return nil, fmt.Errorf("unmarshal plan error: %w", err) } adk.AddSessionValue(ctx, PlanSessionKey, plan) return plan, nil }), ) var opts []compose.Option if p.toolCall { opts = append(opts, compose.WithChatModelOption(model.WithToolChoice(schema.ToolChoiceForced))) } r, err := c.Compile(ctx, compose.WithGraphName(p.Name(ctx))) if err != nil { // unexpected generator.Send(&adk.AgentEvent{Err: err}) return } _, err = r.Stream(ctx, input, opts...) if err != nil { generator.Send(&adk.AgentEvent{Err: err}) return } }() return iterator } // NewPlanner creates a new planner agent based on the provided configuration. // The planner agent uses either ChatModelWithFormattedOutput or ToolCallingChatModel+ToolInfo // to generate structured Plan output. 
// // If ChatModelWithFormattedOutput is provided, it will be used directly. // If ToolCallingChatModel is provided, it will be configured with ToolInfo (or PlanToolInfo by default) // to generate structured Plan output. func NewPlanner(_ context.Context, cfg *PlannerConfig) (adk.Agent, error) { var chatModel model.BaseChatModel var toolCall bool if cfg.ChatModelWithFormattedOutput != nil { chatModel = cfg.ChatModelWithFormattedOutput } else { toolCall = true toolInfo := cfg.ToolInfo if toolInfo == nil { toolInfo = &PlanToolInfo } var err error chatModel, err = cfg.ToolCallingChatModel.WithTools([]*schema.ToolInfo{toolInfo}) if err != nil { return nil, err } } inputFn := cfg.GenInputFn if inputFn == nil { inputFn = defaultGenPlannerInputFn } planParser := cfg.NewPlan if planParser == nil { planParser = defaultNewPlan } return &planner{ toolCall: toolCall, chatModel: chatModel, genInputFn: inputFn, newPlan: planParser, }, nil } // ExecutionContext is the input information for the executor and the planner. type ExecutionContext struct { UserInput []adk.Message Plan Plan ExecutedSteps []ExecutedStep } // GenModelInputFn is a function that generates the input messages for the executor and the planner. type GenModelInputFn func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) // ExecutorConfig provides configuration options for creating an executor agent. type ExecutorConfig struct { // Model is the chat model used by the executor. // If the executor uses any tools, this model must support the model.WithTools call option, // as that's how the executor configures the model with tool information. Model model.BaseChatModel // ToolsConfig specifies the tools available to the executor. ToolsConfig adk.ToolsConfig // MaxIterations defines the upper limit of ChatModel generation cycles. // The agent will terminate with an error if this limit is exceeded. // Optional. Defaults to 20. MaxIterations int // GenInputFn generates the input messages for the Executor. 
// Optional. If not provided, defaultGenExecutorInputFn will be used. GenInputFn GenModelInputFn } type ExecutedStep struct { Step string Result string } // NewExecutor creates a new executor agent. func NewExecutor(ctx context.Context, cfg *ExecutorConfig) (adk.Agent, error) { genInputFn := cfg.GenInputFn if genInputFn == nil { genInputFn = defaultGenExecutorInputFn } genInput := func(ctx context.Context, instruction string, _ *adk.AgentInput) ([]adk.Message, error) { plan, ok := adk.GetSessionValue(ctx, PlanSessionKey) if !ok { panic("impossible: plan not found") } plan_ := plan.(Plan) userInput, ok := adk.GetSessionValue(ctx, UserInputSessionKey) if !ok { panic("impossible: user input not found") } userInput_ := userInput.([]adk.Message) var executedSteps_ []ExecutedStep executedStep, ok := adk.GetSessionValue(ctx, ExecutedStepsSessionKey) if ok { executedSteps_ = executedStep.([]ExecutedStep) } in := &ExecutionContext{ UserInput: userInput_, Plan: plan_, ExecutedSteps: executedSteps_, } msgs, err := genInputFn(ctx, in) if err != nil { return nil, err } return msgs, nil } agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "executor", Description: "an executor agent", Model: cfg.Model, ToolsConfig: cfg.ToolsConfig, GenModelInput: genInput, MaxIterations: cfg.MaxIterations, OutputKey: ExecutedStepSessionKey, }) if err != nil { return nil, err } return agent, nil } func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) { planContent, err := in.Plan.MarshalJSON() if err != nil { return nil, err } userMsgs, err := ExecutorPrompt.Format(ctx, map[string]any{ "input": formatInput(in.UserInput), "plan": string(planContent), "executed_steps": formatExecutedSteps(in.ExecutedSteps), "step": in.Plan.FirstStep(), }) if err != nil { return nil, err } return userMsgs, nil } type replanner struct { chatModel model.ToolCallingChatModel planTool *schema.ToolInfo respondTool *schema.ToolInfo genInputFn GenModelInputFn 
newPlan NewPlan } type ReplannerConfig struct { // ChatModel is the model that supports tool calling capabilities. // It will be configured with PlanTool and RespondTool to generate updated plans or responses. ChatModel model.ToolCallingChatModel // PlanTool defines the schema for the Plan tool that can be used with ToolCallingChatModel. // Optional. If not provided, the default PlanToolInfo will be used. PlanTool *schema.ToolInfo // RespondTool defines the schema for the response tool that can be used with ToolCallingChatModel. // Optional. If not provided, the default RespondToolInfo will be used. RespondTool *schema.ToolInfo // GenInputFn generates the input messages for the Replanner. // Optional. If not provided, buildGenReplannerInputFn will be used. GenInputFn GenModelInputFn // NewPlan creates a new Plan instance. // The returned Plan will be used to unmarshal the model-generated JSON output from PlanTool. // Optional. If not provided, defaultNewPlan will be used. NewPlan NewPlan } // formatInput formats the input messages into a string. 
// Each message contributes its Content followed by a newline; other fields are ignored.
func formatInput(input []adk.Message) string {
	var sb strings.Builder
	for _, msg := range input {
		sb.WriteString(msg.Content)
		sb.WriteString("\n")
	}
	return sb.String()
}

// formatExecutedSteps renders each executed step as a "Step: ...\nResult: ..."
// pair, followed by a blank line.
func formatExecutedSteps(results []ExecutedStep) string {
	var sb strings.Builder
	for _, result := range results {
		// Format straight into the builder instead of allocating an
		// intermediate string with fmt.Sprintf.
		fmt.Fprintf(&sb, "Step: %s\nResult: %s\n\n", result.Step, result.Result)
	}
	return sb.String()
}

// Name returns the replanner's fixed agent name.
func (r *replanner) Name(_ context.Context) string {
	return "replanner"
}

// Description returns the replanner's fixed description.
func (r *replanner) Description(_ context.Context) string {
	return "a replanner agent"
}

// genInput records the step that was just executed and builds the replanner's
// model input.
//
// It reads the executor's latest result (ExecutedStepSessionKey) and pairs it
// with the current plan's first step. It appends that pair to the session's
// ExecutedStepsSessionKey list, which is written back to the session. It then
// renders the input via r.genInputFn, or buildGenReplannerInputFn when unset.
// Missing session values are treated as programmer errors and panic.
func (r *replanner) genInput(ctx context.Context) ([]adk.Message, error) {
	executedStep, ok := adk.GetSessionValue(ctx, ExecutedStepSessionKey)
	if !ok {
		panic("impossible: execute result not found")
	}
	executedStep_ := executedStep.(string)

	plan, ok := adk.GetSessionValue(ctx, PlanSessionKey)
	if !ok {
		panic("impossible: plan not found")
	}
	plan_ := plan.(Plan)
	// The executor always works on the plan's first step, so that is the step
	// the latest result belongs to.
	step := plan_.FirstStep()

	var executedSteps_ []ExecutedStep
	executedSteps, ok := adk.GetSessionValue(ctx, ExecutedStepsSessionKey)
	if ok {
		executedSteps_ = executedSteps.([]ExecutedStep)
	}
	executedSteps_ = append(executedSteps_, ExecutedStep{
		Step:   step,
		Result: executedStep_,
	})
	adk.AddSessionValue(ctx, ExecutedStepsSessionKey, executedSteps_)

	userInput, ok := adk.GetSessionValue(ctx, UserInputSessionKey)
	if !ok {
		panic("impossible: user input not found")
	}
	userInput_ := userInput.([]adk.Message)

	in := &ExecutionContext{
		UserInput:     userInput_,
		Plan:          plan_,
		ExecutedSteps: executedSteps_,
	}

	genInputFn := r.genInputFn
	if genInputFn == nil {
		genInputFn = buildGenReplannerInputFn(r.planTool.Name, r.respondTool.Name)
	}
	msgs, err := genInputFn(ctx, in)
	if err != nil {
		return nil, err
	}
	return msgs, nil
}

// Run executes one replanning round in a background goroutine and returns
// the event iterator.
//
// The model is called with a forced tool choice:
//   - If it calls the respond tool, a BreakLoop action is emitted.
//   - If it calls the plan tool, the arguments are unmarshaled into a new Plan,
//     which replaces PlanSessionKey in the session.
//   - Any other outcome (no tool call, unknown tool, bad plan JSON) is reported
//     as an error event.
//
// Panics are recovered and reported as error events.
func (r *replanner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
	iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]()

	go func() {
		defer func() {
			panicErr := recover()
			if panicErr != nil {
				e := safe.NewPanicErr(panicErr, debug.Stack())
				generator.Send(&adk.AgentEvent{Err: e})
			}
			generator.Close()
		}()

		callOpt := model.WithToolChoice(schema.ToolChoiceForced)

		c := compose.NewChain[struct{}, any]().
			AppendLambda(
				// The chain input is unused; name it "_" so it does not shadow the
				// agent's input, which the next lambda reads.
				compose.InvokableLambda(func(ctx context.Context, _ struct{}) ([]adk.Message, error) {
					return r.genInput(ctx)
				}),
			).
			AppendChatModel(r.chatModel).
			AppendLambda(
				compose.CollectableLambda(func(ctx context.Context, sr *schema.StreamReader[adk.Message]) (adk.Message, error) {
					if input.EnableStreaming {
						// One copy is forwarded to the caller as a stream; the other
						// is concatenated for the tool-call handling below.
						ss := sr.Copy(2)
						sOutput := schema.StreamReaderWithConvert(ss[0], argToContent)
						generator.Send(adk.EventFromMessage(nil, sOutput, schema.Assistant, ""))
						return schema.ConcatMessageStream(ss[1])
					}

					msg, err := schema.ConcatMessageStream(sr)
					if err != nil {
						return nil, err
					}
					if len(msg.ToolCalls) > 0 {
						output := schema.AssistantMessage(msg.ToolCalls[0].Function.Arguments, nil)
						generator.Send(adk.EventFromMessage(output, nil, schema.Assistant, ""))
					}
					return msg, nil
				}),
			).
			AppendLambda(
				compose.InvokableLambda(func(ctx context.Context, msg adk.Message) (msgOrPlan any, err error) {
					if len(msg.ToolCalls) == 0 {
						return nil, fmt.Errorf("no tool call")
					}

					// exit
					if msg.ToolCalls[0].Function.Name == r.respondTool.Name {
						action := adk.NewBreakLoopAction(r.Name(ctx))
						generator.Send(&adk.AgentEvent{Action: action})
						return msg, nil
					}

					// replan
					if msg.ToolCalls[0].Function.Name != r.planTool.Name {
						return nil, fmt.Errorf("unexpected tool call: %s", msg.ToolCalls[0].Function.Name)
					}

					plan := r.newPlan(ctx)
					if err = plan.UnmarshalJSON([]byte(msg.ToolCalls[0].Function.Arguments)); err != nil {
						return nil, fmt.Errorf("unmarshal plan error: %w", err)
					}
					adk.AddSessionValue(ctx, PlanSessionKey, plan)

					return plan, nil
				}),
			)

		runnable, err := c.Compile(ctx, compose.WithGraphName(r.Name(ctx)))
		if err != nil {
			generator.Send(&adk.AgentEvent{Err: err})
			return
		}

		out, err := runnable.Stream(ctx, struct{}{}, compose.WithChatModelOption(callOpt))
		if err != nil {
			generator.Send(&adk.AgentEvent{Err: err})
			return
		}
		// Events are emitted from inside the chain's lambdas; the chain's own
		// output value is not used. Close the reader instead of dropping it.
		out.Close()
	}()

	return iterator
}

// buildGenReplannerInputFn returns a GenModelInputFn that renders
// ReplannerPrompt. It passes the JSON-encoded plan, the user input, the
// executed steps and the names of the plan and respond tools.
func buildGenReplannerInputFn(planToolName, respondToolName string) GenModelInputFn {
	return func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) {
		planContent, err := in.Plan.MarshalJSON()
		if err != nil {
			return nil, err
		}
		msgs, err := ReplannerPrompt.Format(ctx, map[string]any{
			"plan":           string(planContent),
			"input":          formatInput(in.UserInput),
			"executed_steps": formatExecutedSteps(in.ExecutedSteps),
			"plan_tool":      planToolName,
			"respond_tool":   respondToolName,
		})
		if err != nil {
			return nil, err
		}
		return msgs, nil
	}
}

// NewReplanner creates the replanner agent used in the execute-replan loop.
// It binds the plan and respond tools (PlanToolInfo and RespondToolInfo by
// default) to the provided ToolCallingChatModel and returns an Agent. On each
// run, that Agent either updates the plan or responds and breaks the loop.
func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) {
	// ChatModel is required; report its absence as an error instead of letting
	// the WithTools call below panic on a nil interface.
	if cfg == nil || cfg.ChatModel == nil {
		return nil, fmt.Errorf("replanner: ChatModel is required")
	}

	planTool := cfg.PlanTool
	if planTool == nil {
		planTool = &PlanToolInfo
	}
	respondTool := cfg.RespondTool
	if respondTool == nil {
		respondTool = &RespondToolInfo
	}

	// Both tools are bound so the model can choose between replanning and responding.
	chatModel, err := cfg.ChatModel.WithTools([]*schema.ToolInfo{planTool, respondTool})
	if err != nil {
		return nil, err
	}

	newPlan := cfg.NewPlan
	if newPlan == nil {
		newPlan = defaultNewPlan
	}

	return &replanner{
		chatModel:   chatModel,
		planTool:    planTool,
		respondTool: respondTool,
		genInputFn:  cfg.GenInputFn,
		newPlan:     newPlan,
	}, nil
}

// Config provides configuration options for creating a plan-execute-replan agent.
type Config struct {
	// Planner specifies the agent that generates the plan.
	// You can use provided NewPlanner to create a planner agent.
	Planner adk.Agent

	// Executor specifies the agent that executes the plan generated by planner or replanner.
	// You can use provided NewExecutor to create an executor agent.
	Executor adk.Agent

	// Replanner specifies the agent that replans the plan.
	// You can use provided NewReplanner to create a replanner agent.
	Replanner adk.Agent

	// MaxIterations defines the maximum number of loops for 'execute-replan'.
	// Optional. Zero or negative values fall back to the default of 10.
	MaxIterations int
}

// New creates a new plan-execute-replan agent with the given configuration.
// The plan-execute-replan pattern works in three phases:
// 1. Planning: Generate a structured plan with clear, actionable steps
// 2. Execution: Execute the first step of the plan
// 3. Replanning: Evaluate progress and either complete the task or revise the plan
// This approach enables complex problem-solving through iterative refinement.
// The planner runs first, followed by an "execute_replan" loop of Executor
// then Replanner, bounded by cfg.MaxIterations (10 when unset or non-positive).
func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) {
	iterations := 10
	if cfg.MaxIterations > 0 {
		iterations = cfg.MaxIterations
	}

	executeReplan, err := adk.NewLoopAgent(ctx, &adk.LoopAgentConfig{
		Name:          "execute_replan",
		SubAgents:     []adk.Agent{cfg.Executor, cfg.Replanner},
		MaxIterations: iterations,
	})
	if err != nil {
		return nil, err
	}

	return adk.NewSequentialAgent(ctx, &adk.SequentialAgentConfig{
		Name:      "plan_execute_replan",
		SubAgents: []adk.Agent{cfg.Planner, executeReplan},
	})
}

================================================ FILE: adk/prebuilt/planexecute/plan_execute_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ package planexecute import ( "context" "fmt" "strings" "testing" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" mockAdk "github.com/cloudwego/eino/internal/mock/adk" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) // TestNewPlanner tests the NewPlanner function with ChatModelWithFormattedOutput func TestNewPlannerWithFormattedOutput(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock chat model mockChatModel := mockModel.NewMockBaseChatModel(ctrl) // Create the PlannerConfig conf := &PlannerConfig{ ChatModelWithFormattedOutput: mockChatModel, } // Create the planner p, err := NewPlanner(ctx, conf) assert.NoError(t, err) assert.NotNil(t, p) // Verify the planner's name and description assert.Equal(t, "planner", p.Name(ctx)) assert.Equal(t, "a planner agent", p.Description(ctx)) } // TestNewPlannerWithToolCalling tests the NewPlanner function with ToolCallingChatModel func TestNewPlannerWithToolCalling(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock tool calling chat model mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) // Create the PlannerConfig conf := &PlannerConfig{ ToolCallingChatModel: mockToolCallingModel, // Use default instruction and tool info } // Create the planner p, err := NewPlanner(ctx, conf) assert.NoError(t, err) assert.NotNil(t, p) // Verify the planner's name and description assert.Equal(t, "planner", p.Name(ctx)) assert.Equal(t, "a planner agent", p.Description(ctx)) } // 
TestPlannerRunWithFormattedOutput tests the Run method of a planner created with ChatModelWithFormattedOutput func TestPlannerRunWithFormattedOutput(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock chat model mockChatModel := mockModel.NewMockBaseChatModel(ctrl) // Create a plan response planJSON := `{"steps":["Step 1", "Step 2", "Step 3"]}` planMsg := schema.AssistantMessage(planJSON, nil) sr, sw := schema.Pipe[*schema.Message](1) sw.Send(planMsg, nil) sw.Close() // Mock the Generate method mockChatModel.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) // Create the PlannerConfig conf := &PlannerConfig{ ChatModelWithFormattedOutput: mockChatModel, } // Create the planner p, err := NewPlanner(ctx, conf) assert.NoError(t, err) // Run the planner runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: p}) iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("Plan this task")}) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Nil(t, event.Err) msg, _, err := adk.GetMessage(event) assert.NoError(t, err) assert.Equal(t, planMsg.Content, msg.Content) event, ok = iterator.Next() assert.False(t, ok) plan := defaultNewPlan(ctx) err = plan.UnmarshalJSON([]byte(msg.Content)) assert.NoError(t, err) plan_ := plan.(*defaultPlan) assert.Equal(t, 3, len(plan_.Steps)) assert.Equal(t, "Step 1", plan_.Steps[0]) assert.Equal(t, "Step 2", plan_.Steps[1]) assert.Equal(t, "Step 3", plan_.Steps[2]) } // TestPlannerRunWithToolCalling tests the Run method of a planner created with ToolCallingChatModel func TestPlannerRunWithToolCalling(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock tool calling chat model mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) // Create a tool call response with a plan planArgs := 
`{"steps":["Step 1", "Step 2", "Step 3"]}` toolCall := schema.ToolCall{ ID: "tool_call_id", Type: "function", Function: schema.FunctionCall{ Name: "plan", // This should match PlanToolInfo.Name Arguments: planArgs, }, } toolCallMsg := schema.AssistantMessage("", nil) toolCallMsg.ToolCalls = []schema.ToolCall{toolCall} sr, sw := schema.Pipe[*schema.Message](1) sw.Send(toolCallMsg, nil) sw.Close() // Mock the WithTools method to return a model that will be used for Generate mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) // Mock the Generate method to return the tool call message mockToolCallingModel.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) // Create the PlannerConfig with ToolCallingChatModel conf := &PlannerConfig{ ToolCallingChatModel: mockToolCallingModel, // Use default instruction and tool info } // Create the planner p, err := NewPlanner(ctx, conf) assert.NoError(t, err) // Run the planner runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: p}) iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("no input")}) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Nil(t, event.Err) msg, _, err := adk.GetMessage(event) assert.NoError(t, err) assert.Equal(t, planArgs, msg.Content) _, ok = iterator.Next() assert.False(t, ok) plan := defaultNewPlan(ctx) err = plan.UnmarshalJSON([]byte(msg.Content)) assert.NoError(t, err) plan_ := plan.(*defaultPlan) assert.NoError(t, err) assert.Equal(t, 3, len(plan_.Steps)) assert.Equal(t, "Step 1", plan_.Steps[0]) assert.Equal(t, "Step 2", plan_.Steps[1]) assert.Equal(t, "Step 3", plan_.Steps[2]) } // TestNewExecutor tests the NewExecutor function func TestNewExecutor(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock tool calling chat model mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) 
// Create the ExecutorConfig conf := &ExecutorConfig{ Model: mockToolCallingModel, MaxIterations: 3, } // Create the executor executor, err := NewExecutor(ctx, conf) assert.NoError(t, err) assert.NotNil(t, executor) // Verify the executor's name and description assert.Equal(t, "executor", executor.Name(ctx)) assert.Equal(t, "an executor agent", executor.Description(ctx)) } // TestExecutorRun tests the Run method of the executor func TestExecutorRun(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock tool calling chat model mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) // Store a plan in the session plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} adk.AddSessionValue(ctx, PlanSessionKey, plan) // Set up expectations for the mock model // The model should return the last user message as its response mockToolCallingModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, messages []*schema.Message, opts ...model.Option) (*schema.Message, error) { // Find the last user message var lastUserMessage string for _, msg := range messages { if msg.Role == schema.User { lastUserMessage = msg.Content } } // Return the last user message as the model's response return schema.AssistantMessage(lastUserMessage, nil), nil }).Times(1) // Create the ExecutorConfig conf := &ExecutorConfig{ Model: mockToolCallingModel, MaxIterations: 3, } // Create the executor executor, err := NewExecutor(ctx, conf) assert.NoError(t, err) // Run the executor runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: executor}) iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("no input")}, adk.WithSessionValues(map[string]any{ PlanSessionKey: plan, UserInputSessionKey: []adk.Message{schema.UserMessage("no input")}, }), ) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) msg, _, err := adk.GetMessage(event) assert.NoError(t, err) t.Logf("executor model input msg:\n %s\n", msg.Content) _, ok = iterator.Next() assert.False(t, ok) } // TestNewReplanner tests the NewReplanner function func TestNewReplanner(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock tool calling chat model mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) // Mock the WithTools method mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) // Create plan and respond tools planTool := &schema.ToolInfo{ Name: "Plan", Desc: "Plan tool", } respondTool := &schema.ToolInfo{ Name: "Respond", Desc: "Respond tool", } // Create the ReplannerConfig conf := &ReplannerConfig{ ChatModel: mockToolCallingModel, PlanTool: planTool, RespondTool: respondTool, } // Create the replanner rp, err := 
NewReplanner(ctx, conf) assert.NoError(t, err) assert.NotNil(t, rp) // Verify the replanner's name and description assert.Equal(t, "replanner", rp.Name(ctx)) assert.Equal(t, "a replanner agent", rp.Description(ctx)) } // TestReplannerRunWithPlan tests the Replanner's ability to use the plan_tool func TestReplannerRunWithPlan(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock tool calling chat model mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) // Create plan and respond tools planTool := &schema.ToolInfo{ Name: "Plan", Desc: "Plan tool", } respondTool := &schema.ToolInfo{ Name: "Respond", Desc: "Respond tool", } // Create a tool call response for the Plan tool planArgs := `{"steps":["Updated Step 1", "Updated Step 2"]}` toolCall := schema.ToolCall{ ID: "tool_call_id", Type: "function", Function: schema.FunctionCall{ Name: planTool.Name, Arguments: planArgs, }, } toolCallMsg := schema.AssistantMessage("", nil) toolCallMsg.ToolCalls = []schema.ToolCall{toolCall} sr, sw := schema.Pipe[*schema.Message](1) sw.Send(toolCallMsg, nil) sw.Close() // Mock the Generate method mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) mockToolCallingModel.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) // Create the ReplannerConfig conf := &ReplannerConfig{ ChatModel: mockToolCallingModel, PlanTool: planTool, RespondTool: respondTool, } // Create the replanner rp, err := NewReplanner(ctx, conf) assert.NoError(t, err) // Store necessary values in the session plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} rp, err = agentOutputSessionKVs(ctx, rp) assert.NoError(t, err) // Run the replanner runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: rp}) iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("no input")}, adk.WithSessionValues(map[string]any{ PlanSessionKey: 
plan, ExecutedStepSessionKey: "Execution result", UserInputSessionKey: []adk.Message{schema.UserMessage("User input")}, }), ) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Nil(t, event.Err) event, ok = iterator.Next() assert.True(t, ok) kvs := event.Output.CustomizedOutput.(map[string]any) assert.Greater(t, len(kvs), 0) // Verify the updated plan was stored in the session planValue, ok := kvs[PlanSessionKey] assert.True(t, ok) updatedPlan, ok := planValue.(*defaultPlan) assert.True(t, ok) assert.Equal(t, 2, len(updatedPlan.Steps)) assert.Equal(t, "Updated Step 1", updatedPlan.Steps[0]) assert.Equal(t, "Updated Step 2", updatedPlan.Steps[1]) // Verify the execute results were updated executeResultsValue, ok := kvs[ExecutedStepsSessionKey] assert.True(t, ok) executeResults, ok := executeResultsValue.([]ExecutedStep) assert.True(t, ok) assert.Equal(t, 1, len(executeResults)) assert.Equal(t, "Step 1", executeResults[0].Step) assert.Equal(t, "Execution result", executeResults[0].Result) _, ok = iterator.Next() assert.False(t, ok) } // TestReplannerRunWithRespond tests the Replanner's ability to use the respond_tool func TestReplannerRunWithRespond(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create a mock tool calling chat model mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) // Create plan and respond tools planTool := &schema.ToolInfo{ Name: "Plan", Desc: "Plan tool", } respondTool := &schema.ToolInfo{ Name: "Respond", Desc: "Respond tool", } // Create a tool call response for the Respond tool responseArgs := `{"response":"This is the final response to the user"}` toolCall := schema.ToolCall{ ID: "tool_call_id", Type: "function", Function: schema.FunctionCall{ Name: respondTool.Name, Arguments: responseArgs, }, } toolCallMsg := schema.AssistantMessage("", nil) toolCallMsg.ToolCalls = []schema.ToolCall{toolCall} sr, sw := 
schema.Pipe[*schema.Message](1) sw.Send(toolCallMsg, nil) sw.Close() // Mock the Generate method mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) mockToolCallingModel.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) // Create the ReplannerConfig conf := &ReplannerConfig{ ChatModel: mockToolCallingModel, PlanTool: planTool, RespondTool: respondTool, } // Create the replanner rp, err := NewReplanner(ctx, conf) assert.NoError(t, err) // Store necessary values in the session plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} // Run the replanner runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: rp}) iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("no input")}, adk.WithSessionValues(map[string]any{ PlanSessionKey: plan, ExecutedStepSessionKey: "Execution result", UserInputSessionKey: []adk.Message{schema.UserMessage("User input")}, }), ) // Get the event from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Nil(t, event.Err) msg, _, err := adk.GetMessage(event) assert.NoError(t, err) assert.Equal(t, responseArgs, msg.Content) // Verify that an exit action was generated event, ok = iterator.Next() assert.True(t, ok) assert.NotNil(t, event.Action) assert.NotNil(t, event.Action.BreakLoop) assert.False(t, event.Action.BreakLoop.Done) _, ok = iterator.Next() assert.False(t, ok) } // TestNewPlanExecuteAgent tests the New function func TestNewPlanExecuteAgent(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create mock agents mockPlanner := mockAdk.NewMockAgent(ctrl) mockExecutor := mockAdk.NewMockAgent(ctrl) mockReplanner := mockAdk.NewMockAgent(ctrl) // Set up expectations for the mock agents mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() 
mockExecutor.EXPECT().Name(gomock.Any()).Return("executor").AnyTimes() mockExecutor.EXPECT().Description(gomock.Any()).Return("an executor agent").AnyTimes() mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() conf := &Config{ Planner: mockPlanner, Executor: mockExecutor, Replanner: mockReplanner, } // Create the plan execute agent agent, err := New(ctx, conf) assert.NoError(t, err) assert.NotNil(t, agent) } func TestPlanExecuteAgentWithReplan(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create mock agents mockPlanner := mockAdk.NewMockAgent(ctrl) mockExecutor := mockAdk.NewMockAgent(ctrl) mockReplanner := mockAdk.NewMockAgent(ctrl) // Set up expectations for the mock agents mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() mockExecutor.EXPECT().Name(gomock.Any()).Return("executor").AnyTimes() mockExecutor.EXPECT().Description(gomock.Any()).Return("an executor agent").AnyTimes() mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() // Create a plan originalPlan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} // Create an updated plan with fewer steps (after replanning) updatedPlan := &defaultPlan{Steps: []string{"Updated Step 2", "Updated Step 3"}} // Create execute result originalExecuteResult := "Execution result for Step 1" updatedExecuteResult := "Execution result for Updated Step 2" // Create user input userInput := []adk.Message{schema.UserMessage("User task input")} finalResponse := &Response{Response: "Final response to user after executing all steps"} // Mock the planner Run method to set the original plan mockPlanner.EXPECT().Run(gomock.Any(), 
gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() // Set the plan in the session adk.AddSessionValue(ctx, PlanSessionKey, originalPlan) adk.AddSessionValue(ctx, UserInputSessionKey, userInput) // Send a message event planJSON, _ := sonic.MarshalString(originalPlan) msg := schema.AssistantMessage(planJSON, nil) event := adk.EventFromMessage(msg, nil, schema.Assistant, "") generator.Send(event) generator.Close() return iterator }, ).Times(1) // Mock the executor Run method to set the execute result mockExecutor.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() plan, _ := adk.GetSessionValue(ctx, PlanSessionKey) currentPlan := plan.(*defaultPlan) var msg adk.Message // Check if this is the first replanning (original plan has 3 steps) if len(currentPlan.Steps) == 3 { msg = schema.AssistantMessage(originalExecuteResult, nil) adk.AddSessionValue(ctx, ExecutedStepSessionKey, originalExecuteResult) } else { msg = schema.AssistantMessage(updatedExecuteResult, nil) adk.AddSessionValue(ctx, ExecutedStepSessionKey, updatedExecuteResult) } event := adk.EventFromMessage(msg, nil, schema.Assistant, "") generator.Send(event) generator.Close() return iterator }, ).Times(2) // Mock the replanner Run method to first update the plan, then respond to user mockReplanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() // First call: Update the plan // Get the current plan from the session plan, _ := adk.GetSessionValue(ctx, 
PlanSessionKey) currentPlan := plan.(*defaultPlan) // Check if this is the first replanning (original plan has 3 steps) if len(currentPlan.Steps) == 3 { // Send a message event with the updated plan planJSON, _ := sonic.MarshalString(updatedPlan) msg := schema.AssistantMessage(planJSON, nil) event := adk.EventFromMessage(msg, nil, schema.Assistant, "") generator.Send(event) // Set the updated plan & execute result in the session adk.AddSessionValue(ctx, PlanSessionKey, updatedPlan) adk.AddSessionValue(ctx, ExecutedStepsSessionKey, []ExecutedStep{{ Step: currentPlan.Steps[0], Result: originalExecuteResult, }}) } else { // Second call: Respond to user responseJSON, err := sonic.MarshalString(finalResponse) assert.NoError(t, err) msg := schema.AssistantMessage(responseJSON, nil) event := adk.EventFromMessage(msg, nil, schema.Assistant, "") generator.Send(event) // Send exit action action := adk.NewExitAction() generator.Send(&adk.AgentEvent{Action: action}) } generator.Close() return iterator }, ).Times(2) conf := &Config{ Planner: mockPlanner, Executor: mockExecutor, Replanner: mockReplanner, } // Create the plan execute agent agent, err := New(ctx, conf) assert.NoError(t, err) assert.NotNil(t, agent) // Run the agent runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) iterator := runner.Run(ctx, userInput) // Collect all events var events []*adk.AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } // Verify the events assert.Greater(t, len(events), 0) for i, event := range events { eventJSON, e := sonic.MarshalString(event) assert.NoError(t, e) t.Logf("event %d:\n%s", i, eventJSON) } } type interruptibleTool struct { name string t *testing.T } func (m *interruptibleTool) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: m.name, Desc: "A tool that requires human approval before execution", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "action": { Type: 
schema.String, Desc: "The action to perform", Required: true, }, }), }, nil } func (m *interruptibleTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { wasInterrupted, _, _ := tool.GetInterruptState[any](ctx) if !wasInterrupted { return "", tool.Interrupt(ctx, fmt.Sprintf("Tool '%s' requires human approval", m.name)) } isResumeTarget, hasData, data := tool.GetResumeContext[string](ctx) if !isResumeTarget { return "", tool.Interrupt(ctx, fmt.Sprintf("Tool '%s' requires human approval", m.name)) } if hasData { return fmt.Sprintf("Approved action executed with data: %s", data), nil } return "Approved action executed", nil } type checkpointStore struct { data map[string][]byte } func newCheckpointStore() *checkpointStore { return &checkpointStore{data: make(map[string][]byte)} } func (s *checkpointStore) Set(_ context.Context, key string, value []byte) error { s.data[key] = value return nil } func (s *checkpointStore) Get(_ context.Context, key string) ([]byte, bool, error) { v, ok := s.data[key] return v, ok, nil } func formatRunPath(runPath []adk.RunStep) string { if len(runPath) == 0 { return "[]" } var parts []string for _, step := range runPath { parts = append(parts, step.String()) } return "[" + strings.Join(parts, " -> ") + "]" } func formatAgentEvent(event *adk.AgentEvent) string { var sb strings.Builder sb.WriteString(fmt.Sprintf("{AgentName: %q, RunPath: %s", event.AgentName, formatRunPath(event.RunPath))) if event.Output != nil { if event.Output.MessageOutput != nil && event.Output.MessageOutput.Message != nil { msg := event.Output.MessageOutput.Message sb.WriteString(fmt.Sprintf(", Output.Message: {Role: %q, Content: %q}", msg.Role, msg.Content)) } } if event.Action != nil { if event.Action.Interrupted != nil { sb.WriteString(fmt.Sprintf(", Action.Interrupted: {%d contexts}", len(event.Action.Interrupted.InterruptContexts))) } if event.Action.BreakLoop != nil { sb.WriteString(fmt.Sprintf(", Action.BreakLoop: 
{From: %q, Done: %v}", event.Action.BreakLoop.From, event.Action.BreakLoop.Done)) } } if event.Err != nil { sb.WriteString(fmt.Sprintf(", Err: %v", event.Err)) } sb.WriteString("}") return sb.String() } func TestPlanExecuteAgentInterruptResume(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) approvalTool := &interruptibleTool{name: "approve_action", t: t} plan := &defaultPlan{Steps: []string{"Execute action requiring approval", "Complete task"}} userInput := []adk.Message{schema.UserMessage("Please execute the action")} mockPlanner := mockAdk.NewMockAgent(ctrl) mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() mockPlanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() adk.AddSessionValue(ctx, PlanSessionKey, plan) adk.AddSessionValue(ctx, UserInputSessionKey, userInput) planJSON, _ := sonic.MarshalString(plan) msg := schema.AssistantMessage(planJSON, nil) event := adk.EventFromMessage(msg, nil, schema.Assistant, "") generator.Send(event) generator.Close() return iterator }, ).Times(1) toolCallMsg := schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_1", Type: "function", Function: schema.FunctionCall{ Name: "approve_action", Arguments: `{"action": "execute"}`, }, }, }) completionMsg := schema.AssistantMessage("Action approved and executed successfully", nil) mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).AnyTimes() mockToolCallingModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(toolCallMsg, nil).Times(1) mockToolCallingModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(completionMsg, nil).AnyTimes() executor, err := NewExecutor(ctx, &ExecutorConfig{ Model: mockToolCallingModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{approvalTool}, }, }, MaxIterations: 5, }) assert.NoError(t, err) mockReplanner := mockAdk.NewMockAgent(ctrl) mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() mockReplanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() responseJSON := `{"response":"Task completed successfully"}` msg := schema.AssistantMessage(responseJSON, nil) event := adk.EventFromMessage(msg, nil, schema.Assistant, "") generator.Send(event) action := adk.NewBreakLoopAction("replanner") generator.Send(&adk.AgentEvent{Action: action}) generator.Close() return iterator }, ).AnyTimes() agent, err := New(ctx, &Config{ Planner: mockPlanner, Executor: executor, Replanner: mockReplanner, MaxIterations: 5, }) assert.NoError(t, err) store := newCheckpointStore() runner := adk.NewRunner(ctx, adk.RunnerConfig{ Agent: agent, CheckPointStore: store, }) iter := runner.Run(ctx, userInput, adk.WithCheckPointID("test-interrupt-1")) var events []*adk.AgentEvent var interruptEvent *adk.AgentEvent for { event, ok := iter.Next() if !ok { break } if event.Action != nil && event.Action.Interrupted != nil { interruptEvent = event } events = append(events, event) } t.Logf("Total events received: %d", len(events)) for i, event := range events { eventJSON, _ := sonic.MarshalString(event) t.Logf("Event %d: %s", i, eventJSON) } if interruptEvent == nil { t.Fatal("Expected an interrupt event 
from the tool, but none was received") } assert.NotNil(t, interruptEvent.Action.Interrupted, "Should have interrupt info") assert.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts, "Should have interrupt contexts") t.Logf("Interrupt event received with %d contexts", len(interruptEvent.Action.Interrupted.InterruptContexts)) for i, ctx := range interruptEvent.Action.Interrupted.InterruptContexts { t.Logf("Interrupt context %d: ID=%s, Info=%v, Address=%v", i, ctx.ID, ctx.Info, ctx.Address) } var toolInterruptID string for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts { if intCtx.IsRootCause { toolInterruptID = intCtx.ID break } } assert.NotEmpty(t, toolInterruptID, "Should have a root cause interrupt ID") t.Logf("Attempting to resume with interrupt ID: %s", toolInterruptID) resumeIter, err := runner.ResumeWithParams(ctx, "test-interrupt-1", &adk.ResumeParams{ Targets: map[string]any{ toolInterruptID: "approved", }, }) assert.NoError(t, err, "Resume should not error") assert.NotNil(t, resumeIter, "Resume iterator should not be nil") var resumeEvents []*adk.AgentEvent for { event, ok := resumeIter.Next() if !ok { break } resumeEvents = append(resumeEvents, event) } assert.NotEmpty(t, resumeEvents, "Should have resume events") for _, event := range resumeEvents { assert.NoError(t, event.Err, "Resume event should not have error") } var hasToolResponse, hasAssistantCompletion, hasBreakLoop bool for _, event := range resumeEvents { if event.Output != nil && event.Output.MessageOutput != nil { msg := event.Output.MessageOutput.Message if msg != nil { if msg.Role == "tool" && strings.Contains(msg.Content, "Approved action executed") { hasToolResponse = true } if msg.Role == "assistant" && strings.Contains(msg.Content, "approved") { hasAssistantCompletion = true } } } if event.Action != nil && event.Action.BreakLoop != nil && event.Action.BreakLoop.Done { hasBreakLoop = true } } assert.True(t, hasToolResponse, "Should have tool response 
with approved action") assert.True(t, hasAssistantCompletion, "Should have assistant completion message") assert.True(t, hasBreakLoop, "Should have break loop action indicating completion") } ================================================ FILE: adk/prebuilt/planexecute/utils.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package planexecute import ( "context" "github.com/cloudwego/eino/adk" ) type outputSessionKVsAgent struct { adk.Agent } func (o *outputSessionKVsAgent) Run(ctx context.Context, input *adk.AgentInput, options ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() iterator_ := o.Agent.Run(ctx, input, options...) 
go func() { defer generator.Close() for { event, ok := iterator_.Next() if !ok { break } generator.Send(event) } kvs := adk.GetSessionValues(ctx) event := &adk.AgentEvent{ Output: &adk.AgentOutput{CustomizedOutput: kvs}, } generator.Send(event) }() return iterator } func agentOutputSessionKVs(ctx context.Context, agent adk.Agent) (adk.Agent, error) { return &outputSessionKVsAgent{Agent: agent}, nil } ================================================ FILE: adk/prebuilt/supervisor/supervisor.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package supervisor implements the supervisor pattern for multi-agent systems, // where a central agent coordinates a set of sub-agents. // // # Unified Tracing // // The supervisor pattern provides unified tracing support through an internal container. // When using callbacks (e.g., for tracing or observability), the entire supervisor structure // (supervisor agent + all sub-agents) shares a single trace root. This means: // - OnStart is invoked once at the supervisor container level // - The callback-enriched context (containing parent span info) is propagated to all agents // - All agents within the supervisor appear as children of the same trace root // // This is achieved by wrapping the supervisor structure in an internal container that acts // as the single entry point for tracing. 
The container delegates all execution to the // underlying agents while providing a unified identity for callbacks. package supervisor import ( "context" "github.com/cloudwego/eino/adk" ) type Config struct { // Supervisor specifies the agent that will act as the supervisor, coordinating and managing the sub-agents. Supervisor adk.Agent // SubAgents specifies the list of agents that will be supervised and coordinated by the supervisor agent. SubAgents []adk.Agent } // supervisorContainer wraps the entire supervisor structure to provide unified tracing. // When callbacks are registered (e.g., via Runner.Query with WithCallbacks), OnStart/OnEnd // are invoked once for this container, creating a single trace root. The callback-enriched // context is then propagated to the supervisor and all sub-agents, ensuring they share // the same trace parent. // // This container implements Agent and ResumableAgent by delegating to the inner agent. // It provides its own Name and GetType for callback identification. type supervisorContainer struct { name string inner adk.ResumableAgent } func (s *supervisorContainer) Name(_ context.Context) string { return s.name } func (s *supervisorContainer) Description(ctx context.Context) string { return s.inner.Description(ctx) } func (s *supervisorContainer) GetType() string { return "Supervisor" } func (s *supervisorContainer) Run(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { return s.inner.Run(ctx, input, opts...) } func (s *supervisorContainer) Resume(ctx context.Context, info *adk.ResumeInfo, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { return s.inner.Resume(ctx, info, opts...) } // New creates a supervisor-based multi-agent system with the given configuration. // // In the supervisor pattern, a designated supervisor agent coordinates multiple sub-agents. 
// The supervisor can delegate tasks to sub-agents and receive their responses, while // sub-agents can only communicate with the supervisor (not with each other directly). // This hierarchical structure enables complex problem-solving through coordinated agent interactions. // // The returned agent is wrapped in an internal container that provides unified tracing. // When used with Runner and callbacks, all agents within the supervisor structure will // share the same trace root, making it easy to observe the entire multi-agent execution // as a single logical unit. func New(ctx context.Context, conf *Config) (adk.ResumableAgent, error) { subAgents := make([]adk.Agent, 0, len(conf.SubAgents)) supervisorName := conf.Supervisor.Name(ctx) for _, subAgent := range conf.SubAgents { subAgents = append(subAgents, adk.AgentWithDeterministicTransferTo(ctx, &adk.DeterministicTransferConfig{ Agent: subAgent, ToAgentNames: []string{supervisorName}, })) } inner, err := adk.SetSubAgents(ctx, conf.Supervisor, subAgents) if err != nil { return nil, err } return &supervisorContainer{ name: supervisorName, inner: inner, }, nil } ================================================ FILE: adk/prebuilt/supervisor/supervisor_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package supervisor import ( "context" "fmt" "strings" "sync" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" mockAdk "github.com/cloudwego/eino/internal/mock/adk" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) // TestNewSupervisor tests the New function func TestNewSupervisor(t *testing.T) { ctx := context.Background() // Create a mock controller ctrl := gomock.NewController(t) defer ctrl.Finish() // Create mock agents supervisorAgent := mockAdk.NewMockAgent(ctrl) subAgent1 := mockAdk.NewMockAgent(ctrl) subAgent2 := mockAdk.NewMockAgent(ctrl) supervisorAgent.EXPECT().Name(gomock.Any()).Return("SupervisorAgent").AnyTimes() supervisorAgent.EXPECT().Description(gomock.Any()).Return("Supervisor agent description").AnyTimes() subAgent1.EXPECT().Name(gomock.Any()).Return("SubAgent1").AnyTimes() subAgent2.EXPECT().Name(gomock.Any()).Return("SubAgent2").AnyTimes() aMsg, tMsg := adk.GenTransferMessages(ctx, "SubAgent1") i, g := adk.NewAsyncIteratorPair[*adk.AgentEvent]() g.Send(adk.EventFromMessage(aMsg, nil, schema.Assistant, "")) event := adk.EventFromMessage(tMsg, nil, schema.Tool, tMsg.ToolName) event.Action = &adk.AgentAction{TransferToAgent: &adk.TransferToAgentAction{DestAgentName: "SubAgent1"}} g.Send(event) g.Close() supervisorAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) i, g = adk.NewAsyncIteratorPair[*adk.AgentEvent]() subAgent1Msg := schema.AssistantMessage("SubAgent1", nil) g.Send(adk.EventFromMessage(subAgent1Msg, nil, schema.Assistant, "")) g.Close() subAgent1.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) aMsg, tMsg = adk.GenTransferMessages(ctx, "SubAgent2 message") i, g = adk.NewAsyncIteratorPair[*adk.AgentEvent]() 
g.Send(adk.EventFromMessage(aMsg, nil, schema.Assistant, "")) event = adk.EventFromMessage(tMsg, nil, schema.Tool, tMsg.ToolName) event.Action = &adk.AgentAction{TransferToAgent: &adk.TransferToAgentAction{DestAgentName: "SubAgent2"}} g.Send(event) g.Close() supervisorAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) i, g = adk.NewAsyncIteratorPair[*adk.AgentEvent]() subAgent2Msg := schema.AssistantMessage("SubAgent2 message", nil) g.Send(adk.EventFromMessage(subAgent2Msg, nil, schema.Assistant, "")) g.Close() subAgent2.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) i, g = adk.NewAsyncIteratorPair[*adk.AgentEvent]() finishMsg := schema.AssistantMessage("finish", nil) g.Send(adk.EventFromMessage(finishMsg, nil, schema.Assistant, "")) g.Close() supervisorAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) conf := &Config{ Supervisor: supervisorAgent, SubAgents: []adk.Agent{subAgent1, subAgent2}, } multiAgent, err := New(ctx, conf) assert.NoError(t, err) assert.NotNil(t, multiAgent) assert.Equal(t, "SupervisorAgent", multiAgent.Name(ctx)) runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: multiAgent}) aIter := runner.Run(ctx, []adk.Message{schema.UserMessage("test")}) // transfer to agent1 event, ok := aIter.Next() assert.True(t, ok) assert.Equal(t, "SupervisorAgent", event.AgentName) assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) assert.NotEqual(t, 0, len(event.Output.MessageOutput.Message.ToolCalls)) event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SupervisorAgent", event.AgentName) assert.Equal(t, schema.Tool, event.Output.MessageOutput.Role) assert.Equal(t, "SubAgent1", event.Action.TransferToAgent.DestAgentName) // agent1's output event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SubAgent1", event.AgentName) assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) assert.Equal(t, subAgent1Msg.Content, 
event.Output.MessageOutput.Message.Content) // transfer back to supervisor event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SubAgent1", event.AgentName) assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) assert.NotEqual(t, 0, len(event.Output.MessageOutput.Message.ToolCalls)) event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SubAgent1", event.AgentName) assert.Equal(t, schema.Tool, event.Output.MessageOutput.Role) assert.Equal(t, "SupervisorAgent", event.Action.TransferToAgent.DestAgentName) // transfer to agent2 event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SupervisorAgent", event.AgentName) assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) assert.NotEqual(t, 0, len(event.Output.MessageOutput.Message.ToolCalls)) event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SupervisorAgent", event.AgentName) assert.Equal(t, schema.Tool, event.Output.MessageOutput.Role) assert.Equal(t, "SubAgent2", event.Action.TransferToAgent.DestAgentName) // agent1's output event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SubAgent2", event.AgentName) assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) assert.Equal(t, subAgent2Msg.Content, event.Output.MessageOutput.Message.Content) // transfer back to supervisor event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SubAgent2", event.AgentName) assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) assert.NotEqual(t, 0, len(event.Output.MessageOutput.Message.ToolCalls)) event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SubAgent2", event.AgentName) assert.Equal(t, schema.Tool, event.Output.MessageOutput.Role) assert.Equal(t, "SupervisorAgent", event.Action.TransferToAgent.DestAgentName) // finish event, ok = aIter.Next() assert.True(t, ok) assert.Equal(t, "SupervisorAgent", event.AgentName) assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) assert.Equal(t, finishMsg.Content, 
event.Output.MessageOutput.Message.Content) } type approvalInfo struct { ToolName string ArgumentsInJSON string ToolCallID string } func (ai *approvalInfo) String() string { return fmt.Sprintf("tool '%s' interrupted with arguments '%s', waiting for approval", ai.ToolName, ai.ArgumentsInJSON) } type approvalResult struct { Approved bool DisapproveReason *string } func init() { schema.Register[*approvalInfo]() schema.Register[*approvalResult]() } type approvableTool struct { name string t *testing.T } func (m *approvableTool) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: m.name, Desc: "A tool that requires approval before execution", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "action": {Type: schema.String, Desc: "The action to perform"}, }), }, nil } func (m *approvableTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { wasInterrupted, _, storedArguments := tool.GetInterruptState[string](ctx) if !wasInterrupted { return "", tool.StatefulInterrupt(ctx, &approvalInfo{ ToolName: m.name, ArgumentsInJSON: argumentsInJSON, ToolCallID: compose.GetToolCallID(ctx), }, argumentsInJSON) } isResumeTarget, hasData, data := tool.GetResumeContext[*approvalResult](ctx) if !isResumeTarget { return "", tool.StatefulInterrupt(ctx, &approvalInfo{ ToolName: m.name, ArgumentsInJSON: storedArguments, ToolCallID: compose.GetToolCallID(ctx), }, storedArguments) } if !hasData { return "", fmt.Errorf("tool '%s' resumed with no data", m.name) } if data.Approved { return fmt.Sprintf("Tool '%s' executed successfully with args: %s", m.name, storedArguments), nil } if data.DisapproveReason != nil { return fmt.Sprintf("Tool '%s' disapproved, reason: %s", m.name, *data.DisapproveReason), nil } return fmt.Sprintf("Tool '%s' disapproved", m.name), nil } type checkpointStore struct { data map[string][]byte } func newCheckpointStore() *checkpointStore { return &checkpointStore{data: 
make(map[string][]byte)} } func (s *checkpointStore) Set(_ context.Context, key string, value []byte) error { s.data[key] = value return nil } func (s *checkpointStore) Get(_ context.Context, key string) ([]byte, bool, error) { v, ok := s.data[key] return v, ok, nil } type namedAgent struct { adk.ResumableAgent name string description string } func (n *namedAgent) Name(_ context.Context) string { return n.name } func (n *namedAgent) Description(_ context.Context) string { return n.description } func TestNestedSupervisorInterruptResume(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() mockOuterSupervisorModel := mockModel.NewMockToolCallingChatModel(ctrl) mockInnerSupervisorModel := mockModel.NewMockToolCallingChatModel(ctrl) mockWorkerModel := mockModel.NewMockToolCallingChatModel(ctrl) paymentTool := &approvableTool{name: "process_payment", t: t} userInput := []adk.Message{schema.UserMessage("Process a payment of $1000")} mockWorkerModel.EXPECT().WithTools(gomock.Any()).Return(mockWorkerModel, nil).AnyTimes() workerToolCallMsg := schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_payment_1", Type: "function", Function: schema.FunctionCall{ Name: "process_payment", Arguments: `{"action": "process $1000 payment"}`, }, }, }) mockWorkerModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(workerToolCallMsg, nil).Times(1) workerCompletionMsg := schema.AssistantMessage("Payment processed successfully", nil) mockWorkerModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(workerCompletionMsg, nil).AnyTimes() workerAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "payment_worker", Description: "the agent responsible for processing payments", Instruction: "You are a payment processing worker. 
Use the process_payment tool to handle payments.", Model: mockWorkerModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{paymentTool}, }, }, }) assert.NoError(t, err) mockInnerSupervisorModel.EXPECT().WithTools(gomock.Any()).Return(mockInnerSupervisorModel, nil).AnyTimes() innerTransferMsg := schema.AssistantMessage("", []schema.ToolCall{ { ID: "inner_transfer_1", Type: "function", Function: schema.FunctionCall{ Name: "transfer_to_agent", Arguments: `{"agent_name":"payment_worker"}`, }, }, }) mockInnerSupervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(innerTransferMsg, nil).Times(1) innerFinalMsg := schema.AssistantMessage("Payment has been processed and approved.", nil) mockInnerSupervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(innerFinalMsg, nil).AnyTimes() innerSupervisorChatAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "payment_supervisor", Description: "the supervisor agent responsible for payment operations", Instruction: "You are a payment supervisor. Delegate payment tasks to payment_worker.", Model: mockInnerSupervisorModel, Exit: &adk.ExitTool{}, }) assert.NoError(t, err) innerSupervisorAgent, err := New(ctx, &Config{ Supervisor: innerSupervisorChatAgent, SubAgents: []adk.Agent{workerAgent}, }) assert.NoError(t, err) innerSupervisorWrapped := &namedAgent{ ResumableAgent: innerSupervisorAgent, name: "payment_department", description: "the department responsible for all payment-related operations", } mockOuterSupervisorModel.EXPECT().WithTools(gomock.Any()).Return(mockOuterSupervisorModel, nil).AnyTimes() outerTransferMsg := schema.AssistantMessage("", []schema.ToolCall{ { ID: "outer_transfer_1", Type: "function", Function: schema.FunctionCall{ Name: "transfer_to_agent", Arguments: `{"agent_name":"payment_department"}`, }, }, }) mockOuterSupervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(outerTransferMsg, nil).Times(1) outerFinalMsg := schema.AssistantMessage("The payment request has been fully processed by the payment department.", nil) mockOuterSupervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(outerFinalMsg, nil).AnyTimes() outerSupervisorChatAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "company_coordinator", Description: "the top-level coordinator for company operations", Instruction: "You are the company coordinator. Route payment requests to payment_department.", Model: mockOuterSupervisorModel, Exit: &adk.ExitTool{}, }) assert.NoError(t, err) outerSupervisorAgent, err := New(ctx, &Config{ Supervisor: outerSupervisorChatAgent, SubAgents: []adk.Agent{innerSupervisorWrapped}, }) assert.NoError(t, err) outerSupervisorWrapped := &namedAgent{ ResumableAgent: outerSupervisorAgent, name: "headquarters", description: "the company headquarters that coordinates all departments", } store := newCheckpointStore() runner := adk.NewRunner(ctx, adk.RunnerConfig{ Agent: outerSupervisorWrapped, CheckPointStore: store, }) t.Log("========================================") t.Log("Starting Nested Supervisor Integration Test (with namedAgent wrappers)") t.Log("Hierarchy: headquarters(wrapper) -> company_coordinator -> payment_department(wrapper) -> payment_supervisor -> payment_worker -> process_payment tool") t.Log("========================================") checkpointID := "test-nested-supervisor-1" iter := runner.Run(ctx, userInput, adk.WithCheckPointID(checkpointID)) var interruptEvent *adk.AgentEvent eventCount := 0 for { event, ok := iter.Next() if !ok { break } eventCount++ if event.Action != nil && event.Action.Interrupted != nil { interruptEvent = event t.Log("INTERRUPT DETECTED - Deep interrupt from tool within nested supervisor") break } } if interruptEvent == nil { t.Fatal("Expected an interrupt event from the process_payment tool, but none was received") } assert.NotNil(t, 
interruptEvent.Action.Interrupted, "Should have interrupt info") assert.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts, "Should have interrupt contexts") var toolInterruptID string for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts { if intCtx.IsRootCause { toolInterruptID = intCtx.ID break } } assert.NotEmpty(t, toolInterruptID, "Should have a root cause interrupt ID") t.Logf("Resuming with approval for interrupt ID: %s", toolInterruptID) resumeIter, err := runner.ResumeWithParams(ctx, checkpointID, &adk.ResumeParams{ Targets: map[string]any{ toolInterruptID: &approvalResult{Approved: true}, }, }) assert.NoError(t, err, "Resume should not error") assert.NotNil(t, resumeIter, "Resume iterator should not be nil") var resumeEvents []*adk.AgentEvent for { event, ok := resumeIter.Next() if !ok { break } resumeEvents = append(resumeEvents, event) } assert.NotEmpty(t, resumeEvents, "Should have resume events after approval") for _, event := range resumeEvents { assert.NoError(t, event.Err, "Resume event should not have error") } var hasToolResponse, hasTransferBack bool for _, event := range resumeEvents { if event.Output != nil && event.Output.MessageOutput != nil { msg := event.Output.MessageOutput.Message if msg != nil && msg.Role == "tool" && strings.Contains(msg.Content, "executed successfully") { hasToolResponse = true } } if event.Action != nil && event.Action.TransferToAgent != nil { if event.Action.TransferToAgent.DestAgentName == "company_coordinator" { hasTransferBack = true } } } assert.True(t, hasToolResponse, "Should have tool response indicating successful payment processing") assert.True(t, hasTransferBack, "Should have transfer back to outer supervisor indicating completion") } func TestSupervisorExit(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() supervisorAgent := mockAdk.NewMockAgent(ctrl) subAgent := mockAdk.NewMockAgent(ctrl) 
supervisorAgent.EXPECT().Name(gomock.Any()).Return("Supervisor").AnyTimes() supervisorAgent.EXPECT().Description(gomock.Any()).Return("Supervisor description").AnyTimes() subAgent.EXPECT().Name(gomock.Any()).Return("SubAgent").AnyTimes() // 1. Supervisor transfers to SubAgent aMsg, tMsg := adk.GenTransferMessages(ctx, "SubAgent") i1, g1 := adk.NewAsyncIteratorPair[*adk.AgentEvent]() g1.Send(adk.EventFromMessage(aMsg, nil, schema.Assistant, "")) event1 := adk.EventFromMessage(tMsg, nil, schema.Tool, tMsg.ToolName) event1.Action = &adk.AgentAction{TransferToAgent: &adk.TransferToAgentAction{DestAgentName: "SubAgent"}} g1.Send(event1) g1.Close() supervisorAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i1).Times(1) // 2. SubAgent emits Exit action i2, g2 := adk.NewAsyncIteratorPair[*adk.AgentEvent]() exitEvent := &adk.AgentEvent{ AgentName: "SubAgent", Action: &adk.AgentAction{Exit: true}, Output: &adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Role: schema.Assistant, Message: schema.AssistantMessage("Exiting...", nil), }, }, } g2.Send(exitEvent) g2.Close() subAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i2).Times(1) conf := &Config{ Supervisor: supervisorAgent, SubAgents: []adk.Agent{subAgent}, } multiAgent, err := New(ctx, conf) assert.NoError(t, err) runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: multiAgent}) aIter := runner.Run(ctx, []adk.Message{schema.UserMessage("test")}) // Collect events var events []*adk.AgentEvent for { event, ok := aIter.Next() if !ok { break } events = append(events, event) } foundExit := false foundTransferBack := false for _, e := range events { if e.Action != nil { if e.Action.Exit { foundExit = true } if e.Action.TransferToAgent != nil && e.Action.TransferToAgent.DestAgentName == "Supervisor" { foundTransferBack = true } } } assert.True(t, foundExit, "Should have found Exit action") assert.False(t, foundTransferBack, "Should NOT have found Transfer back to Supervisor after Exit") 
} func TestNestedSupervisorExit(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() topSupervisor := mockAdk.NewMockAgent(ctrl) midSupervisor := mockAdk.NewMockAgent(ctrl) worker := mockAdk.NewMockAgent(ctrl) topSupervisor.EXPECT().Name(gomock.Any()).Return("TopSupervisor").AnyTimes() topSupervisor.EXPECT().Description(gomock.Any()).Return("Top supervisor description").AnyTimes() midSupervisor.EXPECT().Name(gomock.Any()).Return("MidSupervisor").AnyTimes() midSupervisor.EXPECT().Description(gomock.Any()).Return("Mid supervisor description").AnyTimes() worker.EXPECT().Name(gomock.Any()).Return("Worker").AnyTimes() // 1. TopSupervisor transfers to MidSupervisor aMsg1, tMsg1 := adk.GenTransferMessages(ctx, "MidSupervisor") i1, g1 := adk.NewAsyncIteratorPair[*adk.AgentEvent]() g1.Send(adk.EventFromMessage(aMsg1, nil, schema.Assistant, "")) event1 := adk.EventFromMessage(tMsg1, nil, schema.Tool, tMsg1.ToolName) event1.Action = &adk.AgentAction{TransferToAgent: &adk.TransferToAgentAction{DestAgentName: "MidSupervisor"}} g1.Send(event1) g1.Close() topSupervisor.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i1).AnyTimes() // 2. MidSupervisor transfers to Worker aMsg2, tMsg2 := adk.GenTransferMessages(ctx, "Worker") i2, g2 := adk.NewAsyncIteratorPair[*adk.AgentEvent]() g2.Send(adk.EventFromMessage(aMsg2, nil, schema.Assistant, "")) event2 := adk.EventFromMessage(tMsg2, nil, schema.Tool, tMsg2.ToolName) event2.Action = &adk.AgentAction{TransferToAgent: &adk.TransferToAgentAction{DestAgentName: "Worker"}} g2.Send(event2) g2.Close() midSupervisor.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i2).AnyTimes() // 3. 
Worker emits Exit action i3, g3 := adk.NewAsyncIteratorPair[*adk.AgentEvent]() exitEvent := &adk.AgentEvent{ AgentName: "Worker", Action: &adk.AgentAction{Exit: true}, Output: &adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Role: schema.Assistant, Message: schema.AssistantMessage("Worker Exiting...", nil), }, }, } g3.Send(exitEvent) g3.Close() worker.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i3).Times(1) // Build Nested System // Mid System: MidSupervisor -> [Worker] midSystem, err := New(ctx, &Config{ Supervisor: midSupervisor, SubAgents: []adk.Agent{worker}, }) assert.NoError(t, err) // We need to give the midSystem the name "MidSupervisor" so TopSupervisor can find it // supervisor.New returns a ResumableAgent that delegates Name() to the supervisor agent. // So midSystem.Name() should already be "MidSupervisor" because midSupervisor.Name() is "MidSupervisor". // Top System: TopSupervisor -> [midSystem] topSystem, err := New(ctx, &Config{ Supervisor: topSupervisor, SubAgents: []adk.Agent{midSystem}, }) assert.NoError(t, err) runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: topSystem}) aIter := runner.Run(ctx, []adk.Message{schema.UserMessage("test nested exit")}) // Collect events var events []*adk.AgentEvent for { event, ok := aIter.Next() if !ok { break } events = append(events, event) } foundExit := false foundTransferBackToMidAfterExit := false foundTransferBackToTopAfterExit := false for _, e := range events { if e.Action != nil { if e.Action.Exit { foundExit = true } if foundExit && e.Action.TransferToAgent != nil { if e.Action.TransferToAgent.DestAgentName == "MidSupervisor" { foundTransferBackToMidAfterExit = true } if e.Action.TransferToAgent.DestAgentName == "TopSupervisor" { foundTransferBackToTopAfterExit = true } } } } assert.True(t, foundExit, "Should have found Exit action") assert.False(t, foundTransferBackToMidAfterExit, "Should NOT have found Transfer back to MidSupervisor after Exit") assert.False(t, 
foundTransferBackToTopAfterExit, "Should NOT have found Transfer back to TopSupervisor after Exit") } func TestChatModelAgentInternalEventsExit(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() supervisorAgent := mockAdk.NewMockAgent(ctrl) workerModel := mockModel.NewMockToolCallingChatModel(ctrl) innerAgent := mockAdk.NewMockAgent(ctrl) supervisorAgent.EXPECT().Name(gomock.Any()).Return("Supervisor").AnyTimes() supervisorAgent.EXPECT().Description(gomock.Any()).Return("Supervisor description").AnyTimes() innerAgent.EXPECT().Name(gomock.Any()).Return("InnerAgent").AnyTimes() innerAgent.EXPECT().Description(gomock.Any()).Return("Inner Agent Description").AnyTimes() // 1. Supervisor transfers to Worker (only once, then exits when worker transfers back) supervisorRunCount := 0 supervisorAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { supervisorRunCount++ iter, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]() go func() { defer gen.Close() if supervisorRunCount == 1 { aMsg, tMsg := adk.GenTransferMessages(ctx, "Worker") gen.Send(adk.EventFromMessage(aMsg, nil, schema.Assistant, "")) event1 := adk.EventFromMessage(tMsg, nil, schema.Tool, tMsg.ToolName) event1.Action = &adk.AgentAction{TransferToAgent: &adk.TransferToAgentAction{DestAgentName: "Worker"}} gen.Send(event1) } else { exitEvent := &adk.AgentEvent{ AgentName: "Supervisor", Action: &adk.AgentAction{Exit: true}, Output: &adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Role: schema.Assistant, Message: schema.AssistantMessage("Supervisor done", nil), }, }, } gen.Send(exitEvent) } }() return iter }).AnyTimes() // 2. 
Worker runs, calls AgentTool (InnerAgent) // Mock WorkerModel behavior workerModel.EXPECT().WithTools(gomock.Any()).Return(workerModel, nil).AnyTimes() // 2.1 Worker generates tool call toolCallMsg := schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_inner_1", Type: "function", Function: schema.FunctionCall{ Name: "InnerAgent", Arguments: `{"request": "do exit"}`, }, }, }) workerModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(toolCallMsg, nil).Times(1) // 2.2 InnerAgent runs and emits Exit innerAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { iter, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]() go func() { defer gen.Close() innerExitEvent := &adk.AgentEvent{ AgentName: "InnerAgent", Action: &adk.AgentAction{Exit: true}, RunPath: []adk.RunStep{}, Output: &adk.AgentOutput{ MessageOutput: &adk.MessageVariant{ Role: schema.Assistant, Message: schema.AssistantMessage("Inner Exiting...", nil), }, }, } gen.Send(innerExitEvent) }() return iter }).AnyTimes() // 2.3 Worker receives tool result (empty string or whatever AgentTool returns on exit/interrupt) // AgentTool implementation details: if Exit action is present, it returns whatever output is there. // The Exit action itself is passed as internal event. // 2.4 Worker generates final response finalMsg := schema.AssistantMessage("Worker Finished", nil) workerModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(finalMsg, nil).AnyTimes() // Build Worker Agent agentTool := adk.NewAgentTool(ctx, innerAgent) workerAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "Worker", Description: "Worker Agent", Model: workerModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{agentTool}, }, EmitInternalEvents: true, // Key configuration }, }) assert.NoError(t, err) // Build System sys, err := New(ctx, &Config{ Supervisor: supervisorAgent, SubAgents: []adk.Agent{workerAgent}, }) assert.NoError(t, err) runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: sys}) aIter := runner.Run(ctx, []adk.Message{schema.UserMessage("start")}) // Collect events var events []*adk.AgentEvent for { event, ok := aIter.Next() if !ok { break } events = append(events, event) } foundInnerExit := false foundTransferBack := false for _, e := range events { // Check for InnerAgent exit event (propagated as internal event) if e.AgentName == "InnerAgent" && e.Action != nil && e.Action.Exit { foundInnerExit = true } // Check for transfer back to Supervisor if e.AgentName == "Worker" && e.Action != nil && e.Action.TransferToAgent != nil && e.Action.TransferToAgent.DestAgentName == "Supervisor" { foundTransferBack = true } } assert.True(t, foundInnerExit, "Should have captured InnerAgent Exit event") assert.True(t, foundTransferBack, "Should have found Transfer back to Supervisor (Worker should NOT be considered exited)") } func TestSupervisorContainerUnifiedTracing(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() supervisorModel := mockModel.NewMockToolCallingChatModel(ctrl) subAgentModel := mockModel.NewMockToolCallingChatModel(ctrl) supervisorModel.EXPECT().WithTools(gomock.Any()).Return(supervisorModel, nil).AnyTimes() subAgentModel.EXPECT().WithTools(gomock.Any()).Return(subAgentModel, nil).AnyTimes() transferMsg := schema.AssistantMessage("", []schema.ToolCall{ { ID: "transfer_1", Type: 
"function", Function: schema.FunctionCall{ Name: "transfer_to_agent", Arguments: `{"agent_name":"SubAgent"}`, }, }, }) supervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(transferMsg, nil).Times(1) subAgentResponse := schema.AssistantMessage("SubAgent response", nil) subAgentModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(subAgentResponse, nil).Times(1) finalResponse := schema.AssistantMessage("Final response", nil) supervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(finalResponse, nil).Times(1) supervisorAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "SupervisorAgent", Description: "Supervisor agent", Instruction: "You are a supervisor", Model: supervisorModel, }) assert.NoError(t, err) subAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "SubAgent", Description: "Sub agent", Instruction: "You are a sub agent", Model: subAgentModel, }) assert.NoError(t, err) multiAgent, err := New(ctx, &Config{ Supervisor: supervisorAgent, SubAgents: []adk.Agent{subAgent}, }) assert.NoError(t, err) assert.Equal(t, "SupervisorAgent", multiAgent.Name(ctx)) typer, ok := multiAgent.(components.Typer) assert.True(t, ok, "Should implement components.Typer") assert.Equal(t, "Supervisor", typer.GetType()) var mu sync.Mutex var onStartCalls []string var onEndCalls []string handler := callbacks.NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Component != adk.ComponentOfAgent { return ctx } mu.Lock() onStartCalls = append(onStartCalls, info.Name+":"+info.Type) mu.Unlock() return ctx }). 
OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { if info.Component != adk.ComponentOfAgent { return ctx } mu.Lock() onEndCalls = append(onEndCalls, info.Name+":"+info.Type) mu.Unlock() if agentOutput := adk.ConvAgentCallbackOutput(output); agentOutput != nil && agentOutput.Events != nil { go func() { for { _, ok := agentOutput.Events.Next() if !ok { break } } }() } return ctx }). Build() runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: multiAgent}) iter := runner.Query(ctx, "hello", adk.WithCallbacks(handler)) for { _, ok := iter.Next() if !ok { break } } mu.Lock() defer mu.Unlock() assert.NotEmpty(t, onStartCalls, "Should have OnStart calls") assert.Contains(t, onStartCalls, "SupervisorAgent:Supervisor", "Should have supervisor container OnStart with type 'Supervisor'") } type traceContextKey struct{} func TestSupervisorContainerUnifiedTracingOnResume(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() supervisorModel := mockModel.NewMockToolCallingChatModel(ctrl) workerModel := mockModel.NewMockToolCallingChatModel(ctrl) supervisorModel.EXPECT().WithTools(gomock.Any()).Return(supervisorModel, nil).AnyTimes() workerModel.EXPECT().WithTools(gomock.Any()).Return(workerModel, nil).AnyTimes() paymentTool := &approvableTool{name: "process_payment", t: t} workerToolCallMsg := schema.AssistantMessage("", []schema.ToolCall{ { ID: "call_payment_1", Type: "function", Function: schema.FunctionCall{ Name: "process_payment", Arguments: `{"action": "process $1000 payment"}`, }, }, }) workerModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(workerToolCallMsg, nil).Times(1) workerCompletionMsg := schema.AssistantMessage("Payment processed successfully", nil) workerModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(workerCompletionMsg, nil).AnyTimes() workerAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "Worker", Description: "Worker agent", Instruction: "You are a worker", Model: workerModel, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{paymentTool}, }, }, }) assert.NoError(t, err) transferMsg := schema.AssistantMessage("", []schema.ToolCall{ { ID: "transfer_1", Type: "function", Function: schema.FunctionCall{ Name: "transfer_to_agent", Arguments: `{"agent_name":"Worker"}`, }, }, }) supervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(transferMsg, nil).Times(1) finalResponse := schema.AssistantMessage("Final response", nil) supervisorModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(finalResponse, nil).AnyTimes() supervisorAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ Name: "SupervisorAgent", Description: "Supervisor agent", Instruction: "You are a supervisor", Model: supervisorModel, Exit: &adk.ExitTool{}, }) assert.NoError(t, err) multiAgent, err := New(ctx, &Config{ Supervisor: supervisorAgent, SubAgents: []adk.Agent{workerAgent}, }) assert.NoError(t, err) store := newCheckpointStore() runner := adk.NewRunner(ctx, adk.RunnerConfig{ Agent: multiAgent, CheckPointStore: store, }) var mu sync.Mutex var runOnStartCalls []string var resumeOnStartCalls []string var resumeParentTraceIDs []string runHandler := callbacks.NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Component != adk.ComponentOfAgent { return ctx } mu.Lock() runOnStartCalls = append(runOnStartCalls, info.Name+":"+info.Type) mu.Unlock() return ctx }). 
OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { if info.Component != adk.ComponentOfAgent { return ctx } if agentOutput := adk.ConvAgentCallbackOutput(output); agentOutput != nil && agentOutput.Events != nil { go func() { for { _, ok := agentOutput.Events.Next() if !ok { break } } }() } return ctx }). Build() checkpointID := "test-unified-tracing-resume" iter := runner.Run(ctx, []adk.Message{schema.UserMessage("Process payment")}, adk.WithCallbacks(runHandler), adk.WithCheckPointID(checkpointID)) var interruptEvent *adk.AgentEvent for { event, ok := iter.Next() if !ok { break } if event.Action != nil && event.Action.Interrupted != nil { interruptEvent = event break } } assert.NotNil(t, interruptEvent, "Should have interrupt event") var toolInterruptID string for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts { if intCtx.IsRootCause { toolInterruptID = intCtx.ID break } } assert.NotEmpty(t, toolInterruptID, "Should have a root cause interrupt ID") mu.Lock() t.Logf("Run OnStart calls: %v", runOnStartCalls) assert.Contains(t, runOnStartCalls, "SupervisorAgent:Supervisor", "Run should have supervisor container OnStart") mu.Unlock() resumeHandler := callbacks.NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Component != adk.ComponentOfAgent { return ctx } mu.Lock() resumeOnStartCalls = append(resumeOnStartCalls, info.Name+":"+info.Type) parentID, _ := ctx.Value(traceContextKey{}).(string) resumeParentTraceIDs = append(resumeParentTraceIDs, info.Name+":parent="+parentID) mu.Unlock() if info.Type == "Supervisor" { return context.WithValue(ctx, traceContextKey{}, "supervisor-trace-id") } return ctx }). 
OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { if info.Component != adk.ComponentOfAgent { return ctx } if agentOutput := adk.ConvAgentCallbackOutput(output); agentOutput != nil && agentOutput.Events != nil { go func() { for { _, ok := agentOutput.Events.Next() if !ok { break } } }() } return ctx }). Build() resumeIter, err := runner.ResumeWithParams(ctx, checkpointID, &adk.ResumeParams{ Targets: map[string]any{ toolInterruptID: &approvalResult{Approved: true}, }, }, adk.WithCallbacks(resumeHandler)) assert.NoError(t, err) for { event, ok := resumeIter.Next() if !ok { break } assert.NoError(t, event.Err) } mu.Lock() defer mu.Unlock() t.Logf("Resume OnStart calls: %v", resumeOnStartCalls) t.Logf("Resume parent trace IDs: %v", resumeParentTraceIDs) assert.NotEmpty(t, resumeOnStartCalls, "Should have OnStart calls during resume") assert.Contains(t, resumeOnStartCalls, "SupervisorAgent:Supervisor", "Resume should have supervisor container OnStart with type 'Supervisor'") foundInnerSupervisorWithParent := false for _, entry := range resumeParentTraceIDs { if strings.Contains(entry, "SupervisorAgent") && !strings.Contains(entry, "parent=supervisor-trace-id") && entry != "SupervisorAgent:parent=" { if strings.Contains(resumeOnStartCalls[0], "Supervisor") { continue } } if strings.Contains(entry, "parent=supervisor-trace-id") { foundInnerSupervisorWithParent = true } } assert.True(t, foundInnerSupervisorWithParent, "Inner agents should have parent trace from Supervisor container during Resume. Got: %v", resumeParentTraceIDs) } ================================================ FILE: adk/react.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "bytes" "context" "encoding/gob" "errors" "io" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) // ErrExceedMaxIterations indicates the agent reached the maximum iterations limit. var ErrExceedMaxIterations = errors.New("exceeds max iterations") // State holds agent runtime state including messages and user-extensible storage. // // Deprecated: This type will be unexported in v1.0.0. Use ChatModelAgentState // in HandlerMiddleware and AgentMiddleware callbacks instead. Direct use of // compose.ProcessState[*State] is discouraged and will stop working in v1.0.0; // use the handler APIs instead. type State struct { Messages []Message Extra map[string]any // Internal fields below - do not access directly. // Kept exported for backward compatibility with existing checkpoints. HasReturnDirectly bool ReturnDirectlyToolCallID string ToolGenActions map[string]*AgentAction AgentName string RemainingIterations int ReturnDirectlyEvent *AgentEvent RetryAttempt int } const ( stateGobNameV07 = "_eino_adk_react_state" // stateGobNameV080 is a v0.8.0-v0.8.3-only alias used after byte-patching // raw checkpoint bytes in preprocessADKCheckpoint. // It must stay the same byte length as stateGobNameV07 so the length-prefixed // gob string in the stream remains valid. 
stateGobNameV080 = "_eino_adk_state_v080_" ) func init() { // Checkpoint compatibility notes: // - ADK/compose checkpoints are gob-encoded and may store state behind `any`, so gob relies on // an on-wire type name to choose a local Go type. // - Gob allows only one local Go type per name, and it treats "struct wire" and "GobEncoder wire" // as incompatible even if the name matches. // // This file maintains 2 epochs of *State decoding: // - v0.7.* and current: "_eino_adk_react_state" + struct wire → decode into *State directly. // State's exported fields are a superset of v0.7, so gob handles missing fields gracefully. // - v0.8.0-v0.8.3: "_eino_adk_react_state" + GobEncoder wire → byte-patched to stateGobNameV080, // decode into stateV080 and migrate. schema.RegisterName[*State](stateGobNameV07) schema.RegisterName[*stateV080](stateGobNameV080) // the following two lines of registration mainly for backward compatibility // when decoding checkpoints created by v0.8.0 - v0.8.3 gob.Register(&AgentEvent{}) gob.Register(int(0)) } func (s *State) getReturnDirectlyEvent() *AgentEvent { return s.ReturnDirectlyEvent } func (s *State) setReturnDirectlyEvent(event *AgentEvent) { s.ReturnDirectlyEvent = event } func (s *State) getRetryAttempt() int { return s.RetryAttempt } func (s *State) setRetryAttempt(attempt int) { s.RetryAttempt = attempt } func (s *State) getReturnDirectlyToolCallID() string { return s.ReturnDirectlyToolCallID } func (s *State) setReturnDirectlyToolCallID(id string) { s.ReturnDirectlyToolCallID = id s.HasReturnDirectly = id != "" } func (s *State) getToolGenActions() map[string]*AgentAction { return s.ToolGenActions } func (s *State) setToolGenAction(key string, action *AgentAction) { if s.ToolGenActions == nil { s.ToolGenActions = make(map[string]*AgentAction) } s.ToolGenActions[key] = action } func (s *State) popToolGenAction(key string) *AgentAction { if s.ToolGenActions == nil { return nil } action := s.ToolGenActions[key] delete(s.ToolGenActions, 
key) return action } func (s *State) getRemainingIterations() int { return s.RemainingIterations } func (s *State) setRemainingIterations(iterations int) { s.RemainingIterations = iterations } func (s *State) decrementRemainingIterations() { current := s.getRemainingIterations() s.RemainingIterations = current - 1 } // stateV080 handles the v0.8.0-v0.8.3 checkpoint format. // In those versions, *State implemented GobEncoder and was registered under // "_eino_adk_react_state". GobEncode serialized a stateSerialization struct // into opaque bytes. This type's GobDecode reads that format. // It is registered under "_eino_adk_state_v080_" — a same-length alias used // only after byte-patching the checkpoint data in preprocessADKCheckpoint. type stateV080 struct { Messages []Message HasReturnDirectly bool ReturnDirectlyToolCallID string ToolGenActions map[string]*AgentAction AgentName string RemainingIterations int RetryAttempt int ReturnDirectlyEvent *AgentEvent Extra map[string]any Internals map[string]any } // stateV080Serialization is the on-wire format that v0.8.0-v0.8.3 GobEncode produced. // It is only used by stateV080.GobDecode to parse those legacy opaque bytes. type stateV080Serialization stateV080 func (sc *stateV080) GobDecode(b []byte) error { ss := &stateV080Serialization{} if err := gob.NewDecoder(bytes.NewReader(b)).Decode(ss); err != nil { return err } sc.Messages = ss.Messages sc.HasReturnDirectly = ss.HasReturnDirectly sc.ReturnDirectlyToolCallID = ss.ReturnDirectlyToolCallID sc.ToolGenActions = ss.ToolGenActions sc.AgentName = ss.AgentName sc.RemainingIterations = ss.RemainingIterations sc.Extra = ss.Extra sc.Internals = ss.Internals return nil } // stateV080ToState converts a legacy *stateV080 (v0.8.0-v0.8.3) to a current *State. 
func stateV080ToState(sc *stateV080) *State { s := &State{ Messages: sc.Messages, HasReturnDirectly: sc.HasReturnDirectly, ReturnDirectlyToolCallID: sc.ReturnDirectlyToolCallID, ToolGenActions: sc.ToolGenActions, AgentName: sc.AgentName, RemainingIterations: sc.RemainingIterations, Extra: sc.Extra, } if sc.ReturnDirectlyToolCallID != "" { s.setReturnDirectlyToolCallID(sc.ReturnDirectlyToolCallID) } if sc.Internals != nil && s.RetryAttempt == 0 { if v, ok := sc.Internals["_retryAttempt"].(int); ok { s.RetryAttempt = v } } if sc.Internals != nil && s.ReturnDirectlyEvent == nil { if v, ok := sc.Internals["_returnDirectlyEvent"].(*AgentEvent); ok { s.ReturnDirectlyEvent = v } } return s } // SendToolGenAction attaches an AgentAction to the next tool event emitted for the // current tool execution. // // Where/when to use: // - Invoke within a tool's Run (Invokable/Streamable) implementation to include // an action alongside that tool's output event. // - The action is scoped by the current tool call context: if a ToolCallID is // available, it is used as the key to support concurrent calls of the same // tool with different parameters; otherwise, the provided toolName is used. // - The stored action is ephemeral and will be popped and attached to the tool // event when the tool finishes (including streaming completion). // // Limitation: // - This function is intended for use within ChatModelAgent runs only. It relies // on ChatModelAgent's internal State to store and pop actions, which is not // available in other agent types. func SendToolGenAction(ctx context.Context, toolName string, action *AgentAction) error { key := toolName toolCallID := compose.GetToolCallID(ctx) if len(toolCallID) > 0 { key = toolCallID } return compose.ProcessState(ctx, func(ctx context.Context, st *State) error { st.setToolGenAction(key, action) return nil }) } type reactInput struct { messages []Message } type reactConfig struct { // model is the chat model used by the react graph. 
// Tools are configured via model.WithTools call option, not the WithTools method. model model.BaseChatModel toolsConfig *compose.ToolsNodeConfig modelWrapperConf *modelWrapperConfig toolsReturnDirectly map[string]bool agentName string maxIterations int } func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) { toolInfos := make([]*schema.ToolInfo, 0, len(config.Tools)) for _, t := range config.Tools { tl, err := t.Info(ctx) if err != nil { return nil, err } toolInfos = append(toolInfos, tl) } return toolInfos, nil } type reactGraph = *compose.Graph[*reactInput, Message] type sToolNodeOutput = *schema.StreamReader[[]Message] type sGraphOutput = MessageStream func getReturnDirectlyToolCallID(ctx context.Context) (string, bool) { var toolCallID string handler := func(_ context.Context, st *State) error { toolCallID = st.getReturnDirectlyToolCallID() return nil } _ = compose.ProcessState(ctx, handler) return toolCallID, toolCallID != "" } func genReactState(config *reactConfig) func(ctx context.Context) *State { return func(ctx context.Context) *State { st := &State{ AgentName: config.agentName, } maxIter := 20 if config.maxIterations > 0 { maxIter = config.maxIterations } st.setRemainingIterations(maxIter) return st } } func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { const ( initNode_ = "Init" chatModel_ = "ChatModel" toolNode_ = "ToolNode" ) g := compose.NewGraph[*reactInput, Message](compose.WithGenLocalState(genReactState(config))) initLambda := func(ctx context.Context, input *reactInput) ([]Message, error) { return input.messages, nil } _ = g.AddLambdaNode(initNode_, compose.InvokableLambda(initLambda), compose.WithNodeName(initNode_)) var wrappedModel model.BaseChatModel = config.model if config.modelWrapperConf != nil { wrappedModel = buildModelWrappers(config.model, config.modelWrapperConf) } toolsNode, err := compose.NewToolNode(ctx, config.toolsConfig) if err != nil { return nil, err 
} modelPreHandle := func(ctx context.Context, input []Message, st *State) ([]Message, error) { if st.getRemainingIterations() <= 0 { return nil, ErrExceedMaxIterations } st.decrementRemainingIterations() return input, nil } _ = g.AddChatModelNode(chatModel_, wrappedModel, compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(chatModel_)) toolPreHandle := func(ctx context.Context, _ Message, st *State) (Message, error) { input := st.Messages[len(st.Messages)-1] returnDirectly := config.toolsReturnDirectly if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil && len(execCtx.runtimeReturnDirectly) > 0 { returnDirectly = execCtx.runtimeReturnDirectly } if len(returnDirectly) > 0 { for i := range input.ToolCalls { toolName := input.ToolCalls[i].Function.Name if _, ok := returnDirectly[toolName]; ok { st.setReturnDirectlyToolCallID(input.ToolCalls[i].ID) } } } return input, nil } toolPostHandle := func(ctx context.Context, out *schema.StreamReader[[]*schema.Message], st *State) (*schema.StreamReader[[]*schema.Message], error) { if event := st.getReturnDirectlyEvent(); event != nil { getChatModelAgentExecCtx(ctx).send(event) st.setReturnDirectlyEvent(nil) } return out, nil } _ = g.AddToolsNode(toolNode_, toolsNode, compose.WithStatePreHandler(toolPreHandle), compose.WithStreamStatePostHandler(toolPostHandle), compose.WithNodeName(toolNode_)) _ = g.AddEdge(compose.START, initNode_) _ = g.AddEdge(initNode_, chatModel_) toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { defer sMsg.Close() for { chunk, err_ := sMsg.Recv() if err_ != nil { if err_ == io.EOF { return compose.END, nil } return "", err_ } if len(chunk.ToolCalls) > 0 { return toolNode_, nil } } } branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, toolNode_: true}) _ = g.AddBranch(chatModel_, branch) if len(config.toolsReturnDirectly) > 0 { const ( toolNodeToEndConverter = "ToolNodeToEndConverter" ) cvt := func(ctx 
context.Context, sToolCallMessages sToolNodeOutput) (sGraphOutput, error) { id, _ := getReturnDirectlyToolCallID(ctx) return schema.StreamReaderWithConvert(sToolCallMessages, func(in []Message) (Message, error) { for _, chunk := range in { if chunk != nil && chunk.ToolCallID == id { return chunk, nil } } return nil, schema.ErrNoValue }), nil } _ = g.AddLambdaNode(toolNodeToEndConverter, compose.TransformableLambda(cvt), compose.WithNodeName(toolNodeToEndConverter)) _ = g.AddEdge(toolNodeToEndConverter, compose.END) checkReturnDirect := func(ctx context.Context, sToolCallMessages sToolNodeOutput) (string, error) { _, ok := getReturnDirectlyToolCallID(ctx) if ok { return toolNodeToEndConverter, nil } return chatModel_, nil } branch = compose.NewStreamGraphBranch(checkReturnDirect, map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) _ = g.AddBranch(toolNode_, branch) } else { _ = g.AddEdge(toolNode_, chatModel_) } return g, nil } ================================================ FILE: adk/react_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adk import ( "bytes" "context" "encoding/gob" "errors" "fmt" "io" "math/rand" "testing" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) type testModelWrapper struct { inner model.ToolCallingChatModel } func TestStateCompatConversions_V080(t *testing.T) { t.Run("stateV080GobDecodeAndToState", func(t *testing.T) { ss := &stateV080Serialization{ ReturnDirectlyToolCallID: "tcid", RemainingIterations: 2, Internals: map[string]any{ "_retryAttempt": 9, "_returnDirectlyEvent": &AgentEvent{AgentName: "agent"}, }, } var buf bytes.Buffer assert.NoError(t, gob.NewEncoder(&buf).Encode(ss)) var legacy stateV080 assert.NoError(t, legacy.GobDecode(buf.Bytes())) s := stateV080ToState(&legacy) assert.Equal(t, "tcid", s.ReturnDirectlyToolCallID) assert.True(t, s.HasReturnDirectly) assert.Equal(t, 2, s.RemainingIterations) assert.Equal(t, 9, s.RetryAttempt) assert.NotNil(t, s.ReturnDirectlyEvent) assert.Equal(t, "agent", s.ReturnDirectlyEvent.AgentName) }) } func TestStateGetToolGenActions(t *testing.T) { st := &State{ ToolGenActions: map[string]*AgentAction{ "k": {}, }, } assert.NotNil(t, st.getToolGenActions()) assert.Contains(t, st.getToolGenActions(), "k") } func (w *testModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { return (&stateModelWrapper{inner: w.inner, original: w.inner}).Generate(ctx, input, opts...) } func (w *testModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { return (&stateModelWrapper{inner: w.inner, original: w.inner}).Stream(ctx, input, opts...) 
} func (w *testModelWrapper) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { newInner, err := w.inner.WithTools(tools) if err != nil { return nil, err } return &testModelWrapper{inner: newInner}, nil } // TestReact tests the newReact function with different scenarios func TestReact(t *testing.T) { // Basic test for newReact function t.Run("Invoke", func(t *testing.T) { ctx := context.Background() // Create a fake tool for testing fakeTool := &fakeToolForTest{ tarCount: 3, } info, err := fakeTool.Info(ctx) assert.NoError(t, err) // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model times := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) (Message, error) { times++ if times <= 2 { return schema.AssistantMessage("hello test", []schema.ToolCall{ { ID: randStrForTest(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStrForTest()), }, }, }), nil } return schema.AssistantMessage("bye", nil), nil }).AnyTimes() cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create a reactConfig config := &reactConfig{ model: &testModelWrapper{inner: cm}, toolsConfig: &compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, toolsReturnDirectly: map[string]bool{}, } graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) compiled, err := graph.Compile(ctx) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", }, }}) assert.NoError(t, err) assert.NotNil(t, result) }) // Test with toolsReturnDirectly t.Run("ToolsReturnDirectly", func(t *testing.T) { ctx := context.Background() // Create a fake tool for testing 
fakeTool := &fakeToolForTest{ tarCount: 3, } info, err := fakeTool.Info(ctx) assert.NoError(t, err) // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model times := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) (Message, error) { times++ if times <= 2 { return schema.AssistantMessage("hello test", []schema.ToolCall{ { ID: randStrForTest(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStrForTest()), }, }, }), nil } return schema.AssistantMessage("bye", nil), nil }).AnyTimes() cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create a reactConfig with toolsReturnDirectly config := &reactConfig{ model: &testModelWrapper{inner: cm}, toolsConfig: &compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, toolsReturnDirectly: map[string]bool{info.Name: true}, } graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) compiled, err := graph.Compile(ctx) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message when tool returns directly result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", }, }}) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, result.Role, schema.Tool) }) // Test streaming functionality t.Run("Stream", func(t *testing.T) { ctx := context.Background() // Create a fake tool for testing fakeTool := &fakeToolForTest{ tarCount: 3, } fakeStreamTool := &fakeStreamToolForTest{ tarCount: 3, } // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model times := 0 cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) ( MessageStream, error) { sr, sw := schema.Pipe[Message](1) defer sw.Close() info, _ := fakeTool.Info(ctx) streamInfo, _ := fakeStreamTool.Info(ctx) times++ if times <= 1 { sw.Send(schema.AssistantMessage("hello test", []schema.ToolCall{ { ID: randStrForTest(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "tool"}`, randStrForTest()), }, }, }), nil) return sr, nil } else if times == 2 { sw.Send(schema.AssistantMessage("hello stream", []schema.ToolCall{ { ID: randStrForTest(), Function: schema.FunctionCall{ Name: streamInfo.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "stream tool"}`, randStrForTest()), }, }, }), nil) return sr, nil } sw.Send(schema.AssistantMessage("bye", nil), nil) return sr, nil }).AnyTimes() cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create a reactConfig config := &reactConfig{ model: &testModelWrapper{inner: cm}, toolsConfig: &compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool, fakeStreamTool}, }, toolsReturnDirectly: map[string]bool{}, } graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) compiled, err := graph.Compile(ctx) assert.NoError(t, err) assert.NotNil(t, compiled) // Test streaming with a user message outStream, err := compiled.Stream(ctx, &reactInput{messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", }, }}) assert.NoError(t, err) assert.NotNil(t, outStream) defer outStream.Close() msgs := make([]Message, 0) for { msg, err_ := outStream.Recv() if err_ != nil { if errors.Is(err_, io.EOF) { break } t.Fatal(err_) } msgs = append(msgs, msg) } assert.NotEmpty(t, msgs) }) // Test streaming with toolsReturnDirectly t.Run("StreamWithToolsReturnDirectly", func(t *testing.T) { ctx := context.Background() // Create a fake tool for testing fakeTool := &fakeToolForTest{ tarCount: 3, } fakeStreamTool := &fakeStreamToolForTest{ 
tarCount: 3, } // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model times := 0 cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) ( MessageStream, error) { sr, sw := schema.Pipe[Message](1) defer sw.Close() info, _ := fakeTool.Info(ctx) streamInfo, _ := fakeStreamTool.Info(ctx) times++ if times <= 1 { sw.Send(schema.AssistantMessage("hello test", []schema.ToolCall{ { ID: randStrForTest(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "tool"}`, randStrForTest()), }, }, }), nil) return sr, nil } else if times == 2 { sw.Send(schema.AssistantMessage("hello stream", []schema.ToolCall{ { ID: randStrForTest(), Function: schema.FunctionCall{ Name: streamInfo.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "stream tool"}`, randStrForTest()), }, }, }), nil) return sr, nil } sw.Send(schema.AssistantMessage("bye", nil), nil) return sr, nil }).AnyTimes() cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() streamInfo, err := fakeStreamTool.Info(ctx) assert.NoError(t, err) // Create a reactConfig with toolsReturnDirectly config := &reactConfig{ model: &testModelWrapper{inner: cm}, toolsConfig: &compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool, fakeStreamTool}, }, toolsReturnDirectly: map[string]bool{streamInfo.Name: true}, } graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) compiled, err := graph.Compile(ctx) assert.NoError(t, err) assert.NotNil(t, compiled) // Reset times counter times = 0 // Test streaming with a user message when tool returns directly outStream, err := compiled.Stream(ctx, &reactInput{messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", }, }}) assert.NoError(t, err) assert.NotNil(t, outStream) msgs := make([]Message, 0) for { msg, err_ := 
outStream.Recv() if err_ != nil { if errors.Is(err_, io.EOF) { break } t.Fatal(err) } assert.Equal(t, msg.Role, schema.Tool) msgs = append(msgs, msg) } outStream.Close() assert.NotEmpty(t, msgs) }) t.Run("MaxIterations", func(t *testing.T) { ctx := context.Background() // Create a fake tool for testing fakeTool := &fakeToolForTest{ tarCount: 3, } info, err := fakeTool.Info(ctx) assert.NoError(t, err) // Create a mock chat model ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // Set up expectations for the mock model times := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) (Message, error) { times++ if times <= 5 { return schema.AssistantMessage("hello test", []schema.ToolCall{ { ID: randStrForTest(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStrForTest()), }, }, }), nil } return schema.AssistantMessage("bye", nil), nil }).AnyTimes() cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // don't exceed max iterations config := &reactConfig{ model: &testModelWrapper{inner: cm}, toolsConfig: &compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, toolsReturnDirectly: map[string]bool{}, maxIterations: 6, } graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) compiled, err := graph.Compile(ctx) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", }, }}) assert.NoError(t, err) assert.Equal(t, result.Content, "bye") // reset chat model times counter times = 0 // exceed max iterations config = &reactConfig{ model: &testModelWrapper{inner: cm}, toolsConfig: &compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, toolsReturnDirectly: map[string]bool{}, maxIterations: 5, } graph, 
err = newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) compiled, err = graph.Compile(ctx) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message result, err = compiled.Invoke(ctx, &reactInput{messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", }, }}) assert.Error(t, err) t.Logf("actual error: %v", err.Error()) assert.ErrorIs(t, err, ErrExceedMaxIterations) assert.Contains(t, err.Error(), ErrExceedMaxIterations.Error()) }) } // Helper types and functions for testing type fakeStreamToolForTest struct { tarCount int curCount int } func (t *fakeStreamToolForTest) StreamableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) ( *schema.StreamReader[string], error) { p := &fakeToolInputForTest{} err := sonic.UnmarshalString(argumentsInJSON, p) if err != nil { return nil, err } if t.curCount >= t.tarCount { s := schema.StreamReaderFromArray([]string{`{"say": "bye"}`}) return s, nil } t.curCount++ s := schema.StreamReaderFromArray([]string{fmt.Sprintf(`{"say": "hello %v"}`, p.Name)}) return s, nil } type fakeToolForTest struct { tarCount int curCount int } func (t *fakeToolForTest) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "test_tool", Desc: "test tool for unit testing", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Desc: "user name for testing", Required: true, Type: schema.String, }, }), }, nil } func (t *fakeStreamToolForTest) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "test_stream_tool", Desc: "test stream tool for unit testing", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Desc: "user name for testing", Required: true, Type: schema.String, }, }), }, nil } func (t *fakeToolForTest) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { p := &fakeToolInputForTest{} err := 
sonic.UnmarshalString(argumentsInJSON, p) if err != nil { return "", err } if t.curCount >= t.tarCount { return `{"say": "bye"}`, nil } t.curCount++ return fmt.Sprintf(`{"say": "hello %v"}`, p.Name), nil } type fakeToolInputForTest struct { Name string `json:"name"` } func randStrForTest() string { seeds := []rune("test seed") b := make([]rune, 8) for i := range b { b[i] = seeds[rand.Intn(len(seeds))] } return string(b) } ================================================ FILE: adk/retry_chatmodel.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "errors" "fmt" "io" "log" "math/rand" "time" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) var ( // ErrExceedMaxRetries is returned when the maximum number of retries has been exceeded. // Use errors.Is to check if an error is due to max retries being exceeded: // // if errors.Is(err, adk.ErrExceedMaxRetries) { // // handle max retries exceeded // } // // Use errors.As to extract the underlying RetryExhaustedError for the last error details: // // var retryErr *adk.RetryExhaustedError // if errors.As(err, &retryErr) { // fmt.Printf("last error was: %v\n", retryErr.LastErr) // } ErrExceedMaxRetries = errors.New("exceeds max retries") ) // RetryExhaustedError is returned when all retry attempts have been exhausted. 
// It wraps the last error that occurred during retry attempts. type RetryExhaustedError struct { LastErr error TotalRetries int } func (e *RetryExhaustedError) Error() string { if e.LastErr != nil { return fmt.Sprintf("exceeds max retries: last error: %v", e.LastErr) } return "exceeds max retries" } func (e *RetryExhaustedError) Unwrap() error { return ErrExceedMaxRetries } // WillRetryError is emitted when a retryable error occurs and a retry will be attempted. // It allows end-users to observe retry events in real-time via AgentEvent. // // Field design rationale: // - ErrStr (exported): Stores the error message string for Gob serialization during checkpointing. // This ensures the error message is preserved after checkpoint restore. // - err (unexported): Stores the original error for Unwrap() support at runtime. // This field is intentionally unexported because Gob serialization would fail for unregistered // concrete error types. Since end-users only need the original error when the AgentEvent first // occurs (not after restoring from checkpoint), skipping serialization is acceptable. // After checkpoint restore, err will be nil and Unwrap() returns nil. type WillRetryError struct { ErrStr string RetryAttempt int err error } func (e *WillRetryError) Error() string { return e.ErrStr } func (e *WillRetryError) Unwrap() error { return e.err } func init() { schema.RegisterName[*WillRetryError]("eino_adk_chatmodel_will_retry_error") } // ModelRetryConfig configures retry behavior for the ChatModel node. // It defines how the agent should handle transient failures when calling the ChatModel. type ModelRetryConfig struct { // MaxRetries specifies the maximum number of retry attempts. // A value of 0 means no retries will be attempted. // A value of 3 means up to 3 retry attempts (4 total calls including the initial attempt). MaxRetries int // IsRetryAble is a function that determines whether an error should trigger a retry. 
// If nil, all errors are considered retry-able. // Return true if the error is transient and the operation should be retried. // Return false if the error is permanent and should be propagated immediately. IsRetryAble func(ctx context.Context, err error) bool // BackoffFunc calculates the delay before the next retry attempt. // The attempt parameter starts at 1 for the first retry. // If nil, a default exponential backoff with jitter is used: // base delay 100ms, exponentially increasing up to 10s max, // with random jitter (0-50% of delay) to prevent thundering herd. BackoffFunc func(ctx context.Context, attempt int) time.Duration } func defaultIsRetryAble(_ context.Context, err error) bool { return err != nil } func defaultBackoff(_ context.Context, attempt int) time.Duration { baseDelay := 100 * time.Millisecond maxDelay := 10 * time.Second if attempt <= 0 { return baseDelay } if attempt > 7 { return maxDelay + time.Duration(rand.Int63n(int64(maxDelay/2))) } delay := baseDelay * time.Duration(1< maxDelay { delay = maxDelay } jitter := time.Duration(rand.Int63n(int64(delay / 2))) return delay + jitter } func genErrWrapper(ctx context.Context, maxRetries, attempt int, isRetryAbleFunc func(ctx context.Context, err error) bool) func(error) error { return func(err error) error { isRetryAble := isRetryAbleFunc == nil || isRetryAbleFunc(ctx, err) hasRetriesLeft := attempt < maxRetries if isRetryAble && hasRetriesLeft { return &WillRetryError{ErrStr: err.Error(), RetryAttempt: attempt, err: err} } return err } } func consumeStreamForError(stream *schema.StreamReader[*schema.Message]) error { defer stream.Close() for { _, err := stream.Recv() if err == io.EOF { return nil } if err != nil { return err } } } // retryModelWrapper wraps a BaseChatModel with retry logic. 
// This is used inside the model wrapper chain, positioned between eventSenderModelWrapper // and stateModelWrapper, so that retry only affects the inner chain (event sending, user wrappers, // callback injection) without re-running state management (BeforeModelRewriteState/AfterModelRewriteState). type retryModelWrapper struct { inner model.BaseChatModel config *ModelRetryConfig } func newRetryModelWrapper(inner model.BaseChatModel, config *ModelRetryConfig) *retryModelWrapper { return &retryModelWrapper{inner: inner, config: config} } func (r *retryModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { isRetryAble := r.config.IsRetryAble if isRetryAble == nil { isRetryAble = defaultIsRetryAble } backoffFunc := r.config.BackoffFunc if backoffFunc == nil { backoffFunc = defaultBackoff } var lastErr error for attempt := 0; attempt <= r.config.MaxRetries; attempt++ { out, err := r.inner.Generate(ctx, input, opts...) if err == nil { return out, nil } if !isRetryAble(ctx, err) { return nil, err } lastErr = err if attempt < r.config.MaxRetries { log.Printf("retrying ChatModel.Generate (attempt %d/%d): %v", attempt+1, r.config.MaxRetries, err) time.Sleep(backoffFunc(ctx, attempt+1)) } } return nil, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} } func (r *retryModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) ( *schema.StreamReader[*schema.Message], error) { isRetryAble := r.config.IsRetryAble if isRetryAble == nil { isRetryAble = defaultIsRetryAble } backoffFunc := r.config.BackoffFunc if backoffFunc == nil { backoffFunc = defaultBackoff } defer func() { _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { st.setRetryAttempt(0) return nil }) }() var lastErr error for attempt := 0; attempt <= r.config.MaxRetries; attempt++ { _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { st.setRetryAttempt(attempt) 
return nil }) stream, err := r.inner.Stream(ctx, input, opts...) if err != nil { if !isRetryAble(ctx, err) { return nil, err } lastErr = err if attempt < r.config.MaxRetries { log.Printf("retrying ChatModel.Stream (attempt %d/%d): %v", attempt+1, r.config.MaxRetries, err) time.Sleep(backoffFunc(ctx, attempt+1)) } continue } copies := stream.Copy(2) checkCopy := copies[0] returnCopy := copies[1] streamErr := consumeStreamForError(checkCopy) if streamErr == nil { return returnCopy, nil } returnCopy.Close() if !isRetryAble(ctx, streamErr) { return nil, streamErr } lastErr = streamErr if attempt < r.config.MaxRetries { log.Printf("retrying ChatModel.Stream (attempt %d/%d): %v", attempt+1, r.config.MaxRetries, streamErr) time.Sleep(backoffFunc(ctx, attempt+1)) } } return nil, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} } ================================================ FILE: adk/runctx.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "bytes" "context" "encoding/gob" "fmt" "sort" "sync" "time" "github.com/cloudwego/eino/schema" ) // runSession CheckpointSchema: persisted via serialization.RunCtx (gob). type runSession struct { Values map[string]any valuesMtx *sync.Mutex Events []*agentEventWrapper LaneEvents *laneEvents mtx sync.Mutex } // laneEvents CheckpointSchema: persisted via serialization.RunCtx (gob). 
type laneEvents struct { Events []*agentEventWrapper Parent *laneEvents } // agentEventWrapper CheckpointSchema: persisted via serialization.RunCtx (gob). type agentEventWrapper struct { *AgentEvent mu sync.Mutex concatenatedMessage Message // TS is the timestamp (in nanoseconds) when this event was created. // It is primarily used by the laneEvents mechanism to order events // from different agents in a multi-agent flow. TS int64 // StreamErr stores the error message if the MessageStream contained an error. // This field guards against multiple calls to getMessageFromWrappedEvent // when the stream has already been consumed and errored. // Normally when StreamErr happens, the Agent will return with the error, // unless retry is configured for the agent generating this stream, in which case // this StreamErr will be of type WillRetryError (indicating retry is pending). StreamErr error } type otherAgentEventWrapperForEncode agentEventWrapper func (a *agentEventWrapper) GobEncode() ([]byte, error) { if a.concatenatedMessage != nil && a.Output != nil && a.Output.MessageOutput != nil && a.Output.MessageOutput.IsStreaming { a.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray([]Message{a.concatenatedMessage}) } buf := &bytes.Buffer{} err := gob.NewEncoder(buf).Encode((*otherAgentEventWrapperForEncode)(a)) if err != nil { return nil, fmt.Errorf("failed to gob encode agent event wrapper: %w", err) } return buf.Bytes(), nil } func (a *agentEventWrapper) GobDecode(b []byte) error { return gob.NewDecoder(bytes.NewReader(b)).Decode((*otherAgentEventWrapperForEncode)(a)) } func newRunSession() *runSession { return &runSession{ Values: make(map[string]any), valuesMtx: &sync.Mutex{}, } } // GetSessionValues returns all session key-value pairs for the current run. 
func GetSessionValues(ctx context.Context) map[string]any { session := getSession(ctx) if session == nil { return map[string]any{} } return session.getValues() } // AddSessionValue sets a single session key-value pair for the current run. func AddSessionValue(ctx context.Context, key string, value any) { session := getSession(ctx) if session == nil { return } session.addValue(key, value) } // AddSessionValues sets multiple session key-value pairs for the current run. func AddSessionValues(ctx context.Context, kvs map[string]any) { session := getSession(ctx) if session == nil { return } session.addValues(kvs) } // GetSessionValue retrieves a session value by key and reports whether it exists. func GetSessionValue(ctx context.Context, key string) (any, bool) { session := getSession(ctx) if session == nil { return nil, false } return session.getValue(key) } func (rs *runSession) addEvent(event *AgentEvent) { wrapper := &agentEventWrapper{AgentEvent: event, TS: time.Now().UnixNano()} // If LaneEvents is not nil, we are in a parallel lane. // Append to the lane's local event slice (lock-free). if rs.LaneEvents != nil { rs.LaneEvents.Events = append(rs.LaneEvents.Events, wrapper) return } // Otherwise, we are on the main path. Append to the shared Events slice (with lock). rs.mtx.Lock() rs.Events = append(rs.Events, wrapper) rs.mtx.Unlock() } func (rs *runSession) getEvents() []*agentEventWrapper { // If there are no in-flight lane events, we can return the main slice directly. if rs.LaneEvents == nil { rs.mtx.Lock() events := rs.Events rs.mtx.Unlock() return events } // If there are in-flight events, we must construct the full view. // First, get the committed history from the main Events slice. rs.mtx.Lock() committedEvents := make([]*agentEventWrapper, len(rs.Events)) copy(committedEvents, rs.Events) rs.mtx.Unlock() // Then, assemble the in-flight events by traversing the linked list. 
// Reading the .Parent pointer is safe without a lock because the parent of a lane is immutable after creation. var laneSlices [][]*agentEventWrapper totalLaneSize := 0 for lane := rs.LaneEvents; lane != nil; lane = lane.Parent { if len(lane.Events) > 0 { laneSlices = append(laneSlices, lane.Events) totalLaneSize += len(lane.Events) } } // Combine committed and in-flight history. finalEvents := make([]*agentEventWrapper, 0, len(committedEvents)+totalLaneSize) finalEvents = append(finalEvents, committedEvents...) for i := len(laneSlices) - 1; i >= 0; i-- { finalEvents = append(finalEvents, laneSlices[i]...) } return finalEvents } func (rs *runSession) getValues() map[string]any { rs.valuesMtx.Lock() values := make(map[string]any, len(rs.Values)) for k, v := range rs.Values { values[k] = v } rs.valuesMtx.Unlock() return values } func (rs *runSession) addValue(key string, value any) { rs.valuesMtx.Lock() rs.Values[key] = value rs.valuesMtx.Unlock() } func (rs *runSession) addValues(kvs map[string]any) { rs.valuesMtx.Lock() for k, v := range kvs { rs.Values[k] = v } rs.valuesMtx.Unlock() } func (rs *runSession) getValue(key string) (any, bool) { rs.valuesMtx.Lock() value, ok := rs.Values[key] rs.valuesMtx.Unlock() return value, ok } type runContext struct { RootInput *AgentInput RunPath []RunStep Session *runSession } func (rc *runContext) isRoot() bool { return len(rc.RunPath) == 1 } func (rc *runContext) deepCopy() *runContext { copied := &runContext{ RootInput: rc.RootInput, RunPath: make([]RunStep, len(rc.RunPath)), Session: rc.Session, } copy(copied.RunPath, rc.RunPath) return copied } type runCtxKey struct{} func getRunCtx(ctx context.Context) *runContext { runCtx, ok := ctx.Value(runCtxKey{}).(*runContext) if !ok { return nil } return runCtx } func setRunCtx(ctx context.Context, runCtx *runContext) context.Context { return context.WithValue(ctx, runCtxKey{}, runCtx) } func initRunCtx(ctx context.Context, agentName string, input *AgentInput) (context.Context, 
*runContext) { runCtx := getRunCtx(ctx) if runCtx != nil { runCtx = runCtx.deepCopy() } else { runCtx = &runContext{Session: newRunSession()} } runCtx.RunPath = append(runCtx.RunPath, RunStep{agentName: agentName}) if runCtx.isRoot() && input != nil { runCtx.RootInput = input } return setRunCtx(ctx, runCtx), runCtx } func joinRunCtxs(parentCtx context.Context, childCtxs ...context.Context) { switch len(childCtxs) { case 0: return case 1: // Optimization for the common case of a single branch. newEvents := unwindLaneEvents(childCtxs...) commitEvents(parentCtx, newEvents) return } // 1. Collect all new events from the leaf nodes of each context's lane. newEvents := unwindLaneEvents(childCtxs...) // 2. Sort the collected events by their creation timestamp for chronological order. sort.Slice(newEvents, func(i, j int) bool { return newEvents[i].TS < newEvents[j].TS }) // 3. Commit the sorted events to the parent. commitEvents(parentCtx, newEvents) } // commitEvents appends a slice of new events to the correct parent lane or main event log. func commitEvents(ctx context.Context, newEvents []*agentEventWrapper) { runCtx := getRunCtx(ctx) if runCtx == nil || runCtx.Session == nil { // Should not happen, but handle defensively. return } // If the context we are committing to is itself a lane, append to its event slice. if runCtx.Session.LaneEvents != nil { runCtx.Session.LaneEvents.Events = append(runCtx.Session.LaneEvents.Events, newEvents...) } else { // Otherwise, commit to the main, shared Events slice with a lock. runCtx.Session.mtx.Lock() runCtx.Session.Events = append(runCtx.Session.Events, newEvents...) runCtx.Session.mtx.Unlock() } } // unwindLaneEvents traverses the LaneEvents of the given contexts and collects // all events from the leaf nodes. 
func unwindLaneEvents(ctxs ...context.Context) []*agentEventWrapper { var allNewEvents []*agentEventWrapper for _, ctx := range ctxs { runCtx := getRunCtx(ctx) if runCtx != nil && runCtx.Session != nil && runCtx.Session.LaneEvents != nil { allNewEvents = append(allNewEvents, runCtx.Session.LaneEvents.Events...) } } return allNewEvents } func forkRunCtx(ctx context.Context) context.Context { parentRunCtx := getRunCtx(ctx) if parentRunCtx == nil || parentRunCtx.Session == nil { // Should not happen in a parallel workflow, but handle defensively. return ctx } // Create a new session for the child lane by manually copying the parent's session fields. // This is crucial to ensure a new mutex is created and that the LaneEvents pointer is unique. childSession := &runSession{ Events: parentRunCtx.Session.Events, // Share the committed history Values: parentRunCtx.Session.Values, // Share the values map valuesMtx: parentRunCtx.Session.valuesMtx, } // Fork the lane events within the new session struct. childSession.LaneEvents = &laneEvents{ Parent: parentRunCtx.Session.LaneEvents, Events: make([]*agentEventWrapper, 0), } // Create a new runContext for the child lane, pointing to the new session. childRunCtx := &runContext{ RootInput: parentRunCtx.RootInput, RunPath: make([]RunStep, len(parentRunCtx.RunPath)), Session: childSession, } copy(childRunCtx.RunPath, parentRunCtx.RunPath) return setRunCtx(ctx, childRunCtx) } // updateRunPathOnly creates a new context with an updated RunPath, but does NOT modify the Address. // This is used by sequential workflows to accumulate execution history for LLM context, // without incorrectly chaining the static addresses of peer agents. func updateRunPathOnly(ctx context.Context, agentNames ...string) context.Context { runCtx := getRunCtx(ctx) if runCtx == nil { // This should not happen in a sequential workflow context, but handle defensively. 
runCtx = &runContext{Session: newRunSession()} } else { runCtx = runCtx.deepCopy() } for _, agentName := range agentNames { runCtx.RunPath = append(runCtx.RunPath, RunStep{agentName: agentName}) } return setRunCtx(ctx, runCtx) } // ClearRunCtx clears the run context of the multi-agents. This is particularly useful // when a customized agent with a multi-agents inside it is set as a subagent of another // multi-agents. In such cases, it's not expected to pass the outside run context to the // inside multi-agents, so this function helps isolate the contexts properly. func ClearRunCtx(ctx context.Context) context.Context { return context.WithValue(ctx, runCtxKey{}, nil) } func ctxWithNewRunCtx(ctx context.Context, input *AgentInput, sharedParentSession bool) context.Context { var session *runSession if sharedParentSession { if parentSession := getSession(ctx); parentSession != nil { session = &runSession{ Values: parentSession.Values, valuesMtx: parentSession.valuesMtx, } } } if session == nil { session = newRunSession() } return setRunCtx(ctx, &runContext{Session: session, RootInput: input}) } func getSession(ctx context.Context) *runSession { runCtx := getRunCtx(ctx) if runCtx != nil { return runCtx.Session } return nil } ================================================ FILE: adk/runctx_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adk import ( "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestSessionValues(t *testing.T) { // Test Case 1: Basic AddSessionValues and GetSessionValues t.Run("BasicSessionValues", func(t *testing.T) { ctx := context.Background() // Create a context with a run session session := newRunSession() runCtx := &runContext{Session: session} ctx = setRunCtx(ctx, runCtx) // Add values to the session values := map[string]any{ "key1": "value1", "key2": 42, "key3": true, } AddSessionValues(ctx, values) // Get all values from the session retrievedValues := GetSessionValues(ctx) // Verify the values were added correctly assert.Equal(t, "value1", retrievedValues["key1"]) assert.Equal(t, 42, retrievedValues["key2"]) assert.Equal(t, true, retrievedValues["key3"]) assert.Len(t, retrievedValues, 3) }) // Test Case 2: AddSessionValues with empty context t.Run("AddSessionValuesEmptyContext", func(t *testing.T) { ctx := context.Background() // Add values to a context without a run session values := map[string]any{ "key1": "value1", } AddSessionValues(ctx, values) // Get values should return empty map retrievedValues := GetSessionValues(ctx) assert.Empty(t, retrievedValues) }) // Test Case 3: GetSessionValues with empty context t.Run("GetSessionValuesEmptyContext", func(t *testing.T) { ctx := context.Background() // Get values from a context without a run session retrievedValues := GetSessionValues(ctx) assert.Empty(t, retrievedValues) }) // Test Case 4: AddSessionValues with nil values t.Run("AddSessionValuesNilValues", func(t *testing.T) { ctx := context.Background() // Create a context with a run session session := newRunSession() runCtx := &runContext{Session: session} ctx = setRunCtx(ctx, runCtx) // Add nil values map AddSessionValues(ctx, nil) // Get values should still be empty retrievedValues := GetSessionValues(ctx) assert.Empty(t, retrievedValues) }) // Test Case 5: AddSessionValues with empty values 
t.Run("AddSessionValuesEmptyValues", func(t *testing.T) { ctx := context.Background() // Create a context with a run session session := newRunSession() runCtx := &runContext{Session: session} ctx = setRunCtx(ctx, runCtx) // Add empty values map AddSessionValues(ctx, map[string]any{}) // Get values should be empty retrievedValues := GetSessionValues(ctx) assert.Empty(t, retrievedValues) }) // Test Case 6: AddSessionValues with complex data types t.Run("AddSessionValuesComplexTypes", func(t *testing.T) { ctx := context.Background() // Create a context with a run session session := newRunSession() runCtx := &runContext{Session: session} ctx = setRunCtx(ctx, runCtx) // Add complex values to the session values := map[string]any{ "string": "hello world", "int": 123, "float": 45.67, "bool": true, "slice": []string{"a", "b", "c"}, "map": map[string]int{"x": 1, "y": 2}, "struct": struct{ Name string }{Name: "test"}, } AddSessionValues(ctx, values) // Get all values from the session retrievedValues := GetSessionValues(ctx) // Verify the complex values were added correctly assert.Equal(t, "hello world", retrievedValues["string"]) assert.Equal(t, 123, retrievedValues["int"]) assert.Equal(t, 45.67, retrievedValues["float"]) assert.Equal(t, true, retrievedValues["bool"]) assert.Equal(t, []string{"a", "b", "c"}, retrievedValues["slice"]) assert.Equal(t, map[string]int{"x": 1, "y": 2}, retrievedValues["map"]) assert.Equal(t, struct{ Name string }{Name: "test"}, retrievedValues["struct"]) assert.Len(t, retrievedValues, 7) }) // Test Case 7: AddSessionValues overwrites existing values t.Run("AddSessionValuesOverwrite", func(t *testing.T) { ctx := context.Background() // Create a context with a run session session := newRunSession() runCtx := &runContext{Session: session} ctx = setRunCtx(ctx, runCtx) // Add initial values initialValues := map[string]any{ "key1": "initial1", "key2": "initial2", } AddSessionValues(ctx, initialValues) // Add values that overwrite some keys 
overwriteValues := map[string]any{ "key1": "overwritten1", "key3": "new3", } AddSessionValues(ctx, overwriteValues) // Get all values from the session retrievedValues := GetSessionValues(ctx) // Verify the values were overwritten correctly assert.Equal(t, "overwritten1", retrievedValues["key1"]) // overwritten assert.Equal(t, "initial2", retrievedValues["key2"]) // unchanged assert.Equal(t, "new3", retrievedValues["key3"]) // new assert.Len(t, retrievedValues, 3) }) // Test Case 8: Concurrent access to session values t.Run("ConcurrentSessionValues", func(t *testing.T) { ctx := context.Background() // Create a context with a run session session := newRunSession() runCtx := &runContext{Session: session} ctx = setRunCtx(ctx, runCtx) // Add initial values initialValues := map[string]any{ "counter": 0, } AddSessionValues(ctx, initialValues) // Simulate concurrent access done := make(chan bool) // Goroutine 1: Add values go func() { for i := 0; i < 100; i++ { values := map[string]any{ "goroutine1": i, } AddSessionValues(ctx, values) } done <- true }() // Goroutine 2: Add different values go func() { for i := 0; i < 100; i++ { values := map[string]any{ "goroutine2": i, } AddSessionValues(ctx, values) } done <- true }() // Wait for both goroutines to complete <-done <-done // Verify that both values were set (last write wins) retrievedValues := GetSessionValues(ctx) assert.Equal(t, 0, retrievedValues["counter"]) assert.Equal(t, 99, retrievedValues["goroutine1"]) assert.Equal(t, 99, retrievedValues["goroutine2"]) }) // Test Case 9: GetSessionValue individual value t.Run("GetSessionValueIndividual", func(t *testing.T) { ctx := context.Background() // Create a context with a run session session := newRunSession() runCtx := &runContext{Session: session} ctx = setRunCtx(ctx, runCtx) // Add values to the session values := map[string]any{ "key1": "value1", "key2": 42, } AddSessionValues(ctx, values) // Get individual values value1, exists1 := GetSessionValue(ctx, "key1") value2, 
exists2 := GetSessionValue(ctx, "key2") value3, exists3 := GetSessionValue(ctx, "nonexistent") // Verify individual values assert.True(t, exists1) assert.Equal(t, "value1", value1) assert.True(t, exists2) assert.Equal(t, 42, value2) assert.False(t, exists3) assert.Nil(t, value3) }) // Test Case 10: AddSessionValue individual value t.Run("AddSessionValueIndividual", func(t *testing.T) { ctx := context.Background() // Create a context with a run session session := newRunSession() runCtx := &runContext{Session: session} ctx = setRunCtx(ctx, runCtx) // Add individual values AddSessionValue(ctx, "key1", "value1") AddSessionValue(ctx, "key2", 42) // Get all values retrievedValues := GetSessionValues(ctx) // Verify the values were added correctly assert.Equal(t, "value1", retrievedValues["key1"]) assert.Equal(t, 42, retrievedValues["key2"]) assert.Len(t, retrievedValues, 2) }) // Test Case 11: AddSessionValue with empty context t.Run("AddSessionValueEmptyContext", func(t *testing.T) { ctx := context.Background() // Add individual value to a context without a run session AddSessionValue(ctx, "key1", "value1") // Get individual value should return false value, exists := GetSessionValue(ctx, "key1") assert.False(t, exists) assert.Nil(t, value) // Get all values should return empty map retrievedValues := GetSessionValues(ctx) assert.Empty(t, retrievedValues) }) // Test Case 12: Integration with run context initialization t.Run("IntegrationWithRunContext", func(t *testing.T) { ctx := context.Background() // Initialize a run context with an agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("test input"), }, } ctx, runCtx := initRunCtx(ctx, "test-agent", input) // Verify the run context was created assert.NotNil(t, runCtx) assert.NotNil(t, runCtx.Session) // Add values to the session values := map[string]any{ "integration_key": "integration_value", } AddSessionValues(ctx, values) // Get values from the session retrievedValues := GetSessionValues(ctx) 
assert.Equal(t, "integration_value", retrievedValues["integration_key"])

		// Verify the run path was set correctly
		assert.Len(t, runCtx.RunPath, 1)
		assert.Equal(t, "test-agent", runCtx.RunPath[0].agentName)
	})
}

// TestForkJoinRunCtx drives forkRunCtx/joinRunCtxs through the nested topology
// A -> Par(B, C1 -> Par(D, E)) -> F and checks, step by step, that:
//   - every fork gets its own Session struct with a LaneEvents node,
//   - top-level lanes have a nil Parent and nested lanes point at the forking lane,
//   - each lane sees the shared history plus only its own events, and
//   - joining commits lane events back into the parent in the order D, E / B, C1, D, E.
func TestForkJoinRunCtx(t *testing.T) {
	// Helper to create a named event
	newEvent := func(name string) *AgentEvent {
		// Add a small sleep to ensure timestamps are distinct
		// NOTE(review): presumably the join step orders lane events by these
		// timestamps, which is why D must be created strictly before E — confirm.
		time.Sleep(1 * time.Millisecond)
		return &AgentEvent{AgentName: name}
	}

	// Helper to get event names from a slice of wrappers
	getEventNames := func(wrappers []*agentEventWrapper) []string {
		names := make([]string, len(wrappers))
		for i, w := range wrappers {
			names[i] = w.AgentName
		}
		return names
	}

	// 1. Setup: Create an initial runContext for the main execution path.
	mainCtx, mainRunCtx := initRunCtx(context.Background(), "Main", nil)

	// 2. Run Agent A
	eventA := newEvent("A")
	mainRunCtx.Session.addEvent(eventA)
	assert.Equal(t, []string{"A"}, getEventNames(mainRunCtx.Session.getEvents()), "After A")

	// 3. Fork for Par(B, C)
	ctxB := forkRunCtx(mainCtx)
	ctxC := forkRunCtx(mainCtx)

	// Assertions for Fork
	runCtxB := getRunCtx(ctxB)
	runCtxC := getRunCtx(ctxC)
	assert.NotSame(t, mainRunCtx.Session, runCtxB.Session, "Session B should be a new struct")
	assert.NotSame(t, mainRunCtx.Session, runCtxC.Session, "Session C should be a new struct")
	assert.NotSame(t, runCtxB.Session, runCtxC.Session, "Sessions B and C should be different")
	assert.Nil(t, mainRunCtx.Session.LaneEvents, "Main session should have no lane events yet")
	assert.NotNil(t, runCtxB.Session.LaneEvents, "Session B should have lane events")
	assert.NotNil(t, runCtxC.Session.LaneEvents, "Session C should have lane events")
	assert.Nil(t, runCtxB.Session.LaneEvents.Parent, "Lane B's parent should be the main (nil) lane")
	assert.Nil(t, runCtxC.Session.LaneEvents.Parent, "Lane C's parent should be the main (nil) lane")

	// 4. Run Agent B
	eventB := newEvent("B")
	runCtxB.Session.addEvent(eventB)
	assert.Equal(t, []string{"A", "B"}, getEventNames(runCtxB.Session.getEvents()), "After B")

	// 5. Run Agent C (and Nested Fork for Par(D, E))
	eventC1 := newEvent("C1")
	runCtxC.Session.addEvent(eventC1)
	assert.Equal(t, []string{"A", "C1"}, getEventNames(runCtxC.Session.getEvents()), "After C1")

	ctxD := forkRunCtx(ctxC)
	ctxE := forkRunCtx(ctxC)

	// Assertions for Nested Fork
	runCtxD := getRunCtx(ctxD)
	runCtxE := getRunCtx(ctxE)
	assert.NotNil(t, runCtxD.Session.LaneEvents.Parent, "Lane D's parent should be Lane C")
	assert.Same(t, runCtxC.Session.LaneEvents, runCtxD.Session.LaneEvents.Parent, "Lane D's parent must be Lane C's node")
	assert.Same(t, runCtxC.Session.LaneEvents, runCtxE.Session.LaneEvents.Parent, "Lane E's parent must be Lane C's node")

	// 6. Run Agents D and E
	eventD := newEvent("D")
	runCtxD.Session.addEvent(eventD)
	eventE := newEvent("E")
	runCtxE.Session.addEvent(eventE)
	// Sibling lanes must not see each other's events before the join.
	assert.Equal(t, []string{"A", "C1", "D"}, getEventNames(runCtxD.Session.getEvents()), "After D")
	assert.Equal(t, []string{"A", "C1", "E"}, getEventNames(runCtxE.Session.getEvents()), "After E")

	// 7. Join Par(D, E)
	joinRunCtxs(ctxC, ctxD, ctxE)

	// Assertions for Nested Join
	// The events should now be committed to Lane C's event slice.
	assert.Equal(t, []string{"A", "C1", "D", "E"}, getEventNames(runCtxC.Session.getEvents()), "After joining D and E")

	// 8. Join Par(B, C)
	joinRunCtxs(mainCtx, ctxB, ctxC)

	// Assertions for Top-Level Join
	// The events should now be committed to the main session's Events slice.
	assert.Equal(t, []string{"A", "B", "C1", "D", "E"}, getEventNames(mainRunCtx.Session.getEvents()), "After joining B and C")

	// 9. Run Agent F
	eventF := newEvent("F")
	mainRunCtx.Session.addEvent(eventF)
	assert.Equal(t, []string{"A", "B", "C1", "D", "E", "F"}, getEventNames(mainRunCtx.Session.getEvents()), "After F")
}

================================================
FILE: adk/runner.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package adk

import (
	"context"
	"fmt"
	"runtime/debug"
	"sync"

	"github.com/cloudwego/eino/internal/core"
	"github.com/cloudwego/eino/internal/safe"
	"github.com/cloudwego/eino/schema"
)

// Runner is the primary entry point for executing an Agent.
// It manages the agent's lifecycle, including starting, resuming, and checkpointing.
type Runner struct {
	// a is the agent to be executed.
	a Agent
	// enableStreaming dictates whether the execution should be in streaming mode.
	enableStreaming bool
	// store is the checkpoint store used to persist agent state upon interruption.
	// If nil, checkpointing is disabled and resuming returns an error.
	store CheckPointStore
}

// CheckPointStore is the store a Runner uses to save checkpoints when an
// execution is interrupted and to load them again on resume.
// It is an alias of core.CheckPointStore.
type CheckPointStore = core.CheckPointStore

// RunnerConfig configures a Runner created by NewRunner.
type RunnerConfig struct {
	// Agent is the agent the Runner executes.
	Agent Agent
	// EnableStreaming is forwarded to the agent as AgentInput.EnableStreaming by Run.
	EnableStreaming bool

	// CheckPointStore, when non-nil, is used to save a checkpoint on interrupt
	// (only if a checkpoint ID is supplied through the run options) and to load
	// it on Resume/ResumeWithParams. When nil, Run never checkpoints and the
	// resume methods return an error.
	CheckPointStore CheckPointStore
}

// ResumeParams contains all parameters needed to resume an execution.
// This struct provides an extensible way to pass resume parameters without
// requiring breaking changes to method signatures.
type ResumeParams struct { // Targets contains the addresses of components to be resumed as keys, // with their corresponding resume data as values Targets map[string]any // Future extensible fields can be added here without breaking changes } // NewRunner creates a Runner that executes an Agent with optional streaming // and checkpoint persistence. func NewRunner(_ context.Context, conf RunnerConfig) *Runner { return &Runner{ enableStreaming: conf.EnableStreaming, a: conf.Agent, store: conf.CheckPointStore, } } // Run starts a new execution of the agent with a given set of messages. // It returns an iterator that yields agent events as they occur. // If the Runner was configured with a CheckPointStore, it will automatically save the agent's state // upon interruption. func (r *Runner) Run(ctx context.Context, messages []Message, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { o := getCommonOptions(nil, opts...) fa := toFlowAgent(ctx, r.a) input := &AgentInput{ Messages: messages, EnableStreaming: r.enableStreaming, } ctx = ctxWithNewRunCtx(ctx, input, o.sharedParentSession) AddSessionValues(ctx, o.sessionValues) iter := fa.Run(ctx, input, opts...) if r.store == nil { return iter } niter, gen := NewAsyncIteratorPair[*AgentEvent]() go r.handleIter(ctx, iter, gen, o.checkPointID) return niter } // Query is a convenience method that starts a new execution with a single user query string. func (r *Runner) Query(ctx context.Context, query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) } // Resume continues an interrupted execution from a checkpoint, using an "Implicit Resume All" strategy. // This method is best for simpler use cases where the act of resuming implies that all previously // interrupted points should proceed without specific data. 
// // When using this method, all interrupted agents will receive `isResumeFlow = false` when they // call `GetResumeContext`, as no specific agent was targeted. This is suitable for the "Simple Confirmation" // pattern where an agent only needs to know `wasInterrupted` is true to continue. func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( *AsyncIterator[*AgentEvent], error) { return r.resume(ctx, checkPointID, nil, opts...) } // ResumeWithParams continues an interrupted execution from a checkpoint with specific parameters. // This is the most common and powerful way to resume, allowing you to target specific interrupt points // (identified by their address/ID) and provide them with data. // // The params.Targets map should contain the addresses of the components to be resumed as keys. These addresses // can point to any interruptible component in the entire execution graph, including ADK agents, compose // graph nodes, or tools. The value can be the resume data for that component, or `nil` if no data is needed. // // When using this method: // - Components whose addresses are in the params.Targets map will receive `isResumeFlow = true` when they // call `GetResumeContext`. // - Interrupted components whose addresses are NOT in the params.Targets map must decide how to proceed: // -- "Leaf" components (the actual root causes of the original interrupt) MUST re-interrupt themselves // to preserve their state. // -- "Composite" agents (like SequentialAgent or ChatModelAgent) should generally proceed with their // execution. They act as conduits, allowing the resume signal to flow to their children. They will // naturally re-interrupt if one of their interrupted children re-interrupts, as they receive the // new `CompositeInterrupt` signal from them. 
func (r *Runner) ResumeWithParams(ctx context.Context, checkPointID string, params *ResumeParams, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { return r.resume(ctx, checkPointID, params.Targets, opts...) } // resume is the internal implementation for both Resume and ResumeWithParams. func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map[string]any, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { if r.store == nil { return nil, fmt.Errorf("failed to resume: store is nil") } ctx, runCtx, resumeInfo, err := r.loadCheckPoint(ctx, checkPointID) if err != nil { return nil, fmt.Errorf("failed to load from checkpoint: %w", err) } o := getCommonOptions(nil, opts...) if o.sharedParentSession { parentSession := getSession(ctx) if parentSession != nil { runCtx.Session.Values = parentSession.Values runCtx.Session.valuesMtx = parentSession.valuesMtx } } if runCtx.Session.valuesMtx == nil { runCtx.Session.valuesMtx = &sync.Mutex{} } if runCtx.Session.Values == nil { runCtx.Session.Values = make(map[string]any) } ctx = setRunCtx(ctx, runCtx) AddSessionValues(ctx, o.sessionValues) if len(resumeData) > 0 { ctx = core.BatchResumeWithData(ctx, resumeData) } fa := toFlowAgent(ctx, r.a) aIter := fa.Resume(ctx, resumeInfo, opts...) 
if r.store == nil { return aIter, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() go r.handleIter(ctx, aIter, gen, &checkPointID) return niter, nil } func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent], checkPointID *string) { defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) gen.Send(&AgentEvent{Err: e}) } gen.Close() }() var ( interruptSignal *core.InterruptSignal legacyData any ) for { event, ok := aIter.Next() if !ok { break } if event.Action != nil && event.Action.internalInterrupted != nil { if interruptSignal != nil { // even if multiple interrupt happens, they should be merged into one // action by CompositeInterrupt, so here in Runner we must assume at most // one interrupt action happens panic("multiple interrupt actions should not happen in Runner") } interruptSignal = event.Action.internalInterrupted interruptContexts := core.ToInterruptContexts(interruptSignal, allowedAddressSegmentTypes) event = &AgentEvent{ AgentName: event.AgentName, RunPath: event.RunPath, Output: event.Output, Action: &AgentAction{ Interrupted: &InterruptInfo{ Data: event.Action.Interrupted.Data, InterruptContexts: interruptContexts, }, internalInterrupted: interruptSignal, }, } legacyData = event.Action.Interrupted.Data if checkPointID != nil { // save checkpoint first before sending interrupt event, // so when end-user receives interrupt event, they can resume from this checkpoint err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{ Data: legacyData, }, interruptSignal) if err != nil { gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint: %w", err)}) } } } gen.Send(event) } } ================================================ FILE: adk/runner_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except 
in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) // mockRunnerAgent is a simple implementation of the Agent interface for testing Runner type mockRunnerAgent struct { name string description string responses []*AgentEvent // Track calls to verify correct parameters were passed callCount int lastInput *AgentInput enableStreaming bool } func (a *mockRunnerAgent) Name(_ context.Context) string { return a.name } func (a *mockRunnerAgent) Description(_ context.Context) string { return a.description } func (a *mockRunnerAgent) Run(_ context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { // Record the call details for verification a.callCount++ a.lastInput = input a.enableStreaming = input.EnableStreaming iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go func() { defer generator.Close() for _, event := range a.responses { generator.Send(event) // If the event has an Exit action, stop sending events if event.Action != nil && event.Action.Exit { break } } }() return iterator } func newMockRunnerAgent(name, description string, responses []*AgentEvent) *mockRunnerAgent { return &mockRunnerAgent{ name: name, description: description, responses: responses, } } func TestNewRunner(t *testing.T) { ctx := context.Background() config := RunnerConfig{} runner := NewRunner(ctx, config) // Verify that a non-nil runner is returned assert.NotNil(t, runner) } func TestRunner_Run(t *testing.T) { ctx := context.Background() // Create a mock agent 
with predefined responses mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent for Runner", []*AgentEvent{ { AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Response from test agent", nil), Role: schema.Assistant, }, }}, }) // Create a runner runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_}) // Create test messages msgs := []Message{ schema.UserMessage("Hello, agent!"), } // Test Run method without streaming iterator := runner.Run(ctx, msgs) // Verify that the agent's Run method was called with the correct parameters assert.Equal(t, 1, mockAgent_.callCount) assert.Equal(t, msgs, mockAgent_.lastInput.Messages) assert.False(t, mockAgent_.enableStreaming) // Verify that we can get the expected response from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Equal(t, "TestAgent", event.AgentName) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) assert.NotNil(t, event.Output.MessageOutput.Message) assert.Equal(t, "Response from test agent", event.Output.MessageOutput.Message.Content) // Verify that the iterator is now closed _, ok = iterator.Next() assert.False(t, ok) } func TestRunner_Run_WithStreaming(t *testing.T) { ctx := context.Background() // Create a mock agent with predefined responses mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent for Runner", []*AgentEvent{ { AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, Message: nil, MessageStream: schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("Streaming response", nil)}), Role: schema.Assistant, }, }}, }) // Create a runner runner := NewRunner(ctx, RunnerConfig{EnableStreaming: true, Agent: mockAgent_}) // Create test messages msgs := []Message{ schema.UserMessage("Hello, agent!"), } // Test Run method with streaming enabled iterator := runner.Run(ctx, msgs) // Verify that the agent's Run method was 
called with the correct parameters assert.Equal(t, 1, mockAgent_.callCount) assert.Equal(t, msgs, mockAgent_.lastInput.Messages) assert.True(t, mockAgent_.enableStreaming) // Verify that we can get the expected response from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Equal(t, "TestAgent", event.AgentName) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) assert.True(t, event.Output.MessageOutput.IsStreaming) // Verify that the iterator is now closed _, ok = iterator.Next() assert.False(t, ok) } func TestRunner_Query(t *testing.T) { ctx := context.Background() // Create a mock agent with predefined responses mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent for Runner", []*AgentEvent{ { AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Response to query", nil), Role: schema.Assistant, }, }}, }) // Create a runner runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_}) // Test Query method iterator := runner.Query(ctx, "Test query") // Verify that the agent's Run method was called with the correct parameters assert.Equal(t, 1, mockAgent_.callCount) assert.Equal(t, 1, len(mockAgent_.lastInput.Messages)) assert.Equal(t, "Test query", mockAgent_.lastInput.Messages[0].Content) assert.False(t, mockAgent_.enableStreaming) // Verify that we can get the expected response from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Equal(t, "TestAgent", event.AgentName) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) assert.NotNil(t, event.Output.MessageOutput.Message) assert.Equal(t, "Response to query", event.Output.MessageOutput.Message.Content) // Verify that the iterator is now closed _, ok = iterator.Next() assert.False(t, ok) } func TestRunner_Query_WithStreaming(t *testing.T) { ctx := context.Background() // Create a mock agent with predefined responses mockAgent_ := 
newMockRunnerAgent("TestAgent", "Test agent for Runner", []*AgentEvent{ { AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, Message: nil, MessageStream: schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("Streaming query response", nil)}), Role: schema.Assistant, }, }}, }) // Create a runner runner := NewRunner(ctx, RunnerConfig{EnableStreaming: true, Agent: mockAgent_}) // Test Query method with streaming enabled iterator := runner.Query(ctx, "Test query") // Verify that the agent's Run method was called with the correct parameters assert.Equal(t, 1, mockAgent_.callCount) assert.Equal(t, 1, len(mockAgent_.lastInput.Messages)) assert.Equal(t, "Test query", mockAgent_.lastInput.Messages[0].Content) assert.True(t, mockAgent_.enableStreaming) // Verify that we can get the expected response from the iterator event, ok := iterator.Next() assert.True(t, ok) assert.Equal(t, "TestAgent", event.AgentName) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) assert.True(t, event.Output.MessageOutput.IsStreaming) // Verify that the iterator is now closed _, ok = iterator.Next() assert.False(t, ok) } ================================================ FILE: adk/utils.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adk import ( "context" "errors" "io" "strings" "github.com/google/uuid" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/schema" ) type AsyncIterator[T any] struct { ch *internal.UnboundedChan[T] } func (ai *AsyncIterator[T]) Next() (T, bool) { return ai.ch.Receive() } type AsyncGenerator[T any] struct { ch *internal.UnboundedChan[T] } func (ag *AsyncGenerator[T]) Send(v T) { ag.ch.Send(v) } func (ag *AsyncGenerator[T]) Close() { ag.ch.Close() } // NewAsyncIteratorPair returns a paired async iterator and generator // that share the same underlying channel. func NewAsyncIteratorPair[T any]() (*AsyncIterator[T], *AsyncGenerator[T]) { ch := internal.NewUnboundedChan[T]() return &AsyncIterator[T]{ch}, &AsyncGenerator[T]{ch} } func copyMap[K comparable, V any](m map[K]V) map[K]V { res := make(map[K]V, len(m)) for k, v := range m { res[k] = v } return res } func cloneSlice[T any](s []T) []T { if s == nil { return nil } res := make([]T, len(s)) copy(res, s) return res } func concatInstructions(instructions ...string) string { var sb strings.Builder sb.WriteString(instructions[0]) for i := 1; i < len(instructions); i++ { sb.WriteString("\n\n") sb.WriteString(instructions[i]) } return sb.String() } // GenTransferMessages generates assistant and tool messages to instruct a // transfer-to-agent tool call targeting the destination agent. 
func GenTransferMessages(_ context.Context, destAgentName string) (Message, Message) { toolCallID := uuid.NewString() tooCall := schema.ToolCall{ID: toolCallID, Function: schema.FunctionCall{Name: TransferToAgentToolName, Arguments: destAgentName}} assistantMessage := schema.AssistantMessage("", []schema.ToolCall{tooCall}) msg := transferToAgentToolOutput(destAgentName) toolMessage := schema.ToolMessage(msg, toolCallID, schema.WithToolName(TransferToAgentToolName)) return assistantMessage, toolMessage } // set automatic close for event's message stream func setAutomaticClose(e *AgentEvent) { if e.Output == nil || e.Output.MessageOutput == nil || !e.Output.MessageOutput.IsStreaming { return } e.Output.MessageOutput.MessageStream.SetAutomaticClose() } // getMessageFromWrappedEvent extracts the message from an AgentEvent. // If the stream contains an error chunk, this function returns (nil, err) and // sets StreamErr to prevent re-consumption. The nil message ensures that // failed stream responses are not included in subsequent agents' context windows. func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { if e.AgentEvent.Output == nil || e.AgentEvent.Output.MessageOutput == nil { return nil, nil } if !e.AgentEvent.Output.MessageOutput.IsStreaming { return e.AgentEvent.Output.MessageOutput.Message, nil } if e.concatenatedMessage != nil { return e.concatenatedMessage, nil } if e.StreamErr != nil { return nil, e.StreamErr } e.mu.Lock() defer e.mu.Unlock() if e.concatenatedMessage != nil { return e.concatenatedMessage, nil } var ( msgs []Message s = e.AgentEvent.Output.MessageOutput.MessageStream ) defer s.Close() for { msg, err := s.Recv() if err != nil { if err == io.EOF { break } e.StreamErr = err // Replace the stream with successfully received messages only (no error at the end). // The error is preserved in StreamErr for users to check. 
// We intentionally exclude the error from the new stream to ensure gob encoding // compatibility, as the stream may be consumed during serialization. e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) return nil, err } msgs = append(msgs, msg) } if len(msgs) == 0 { return nil, errors.New("no messages in MessageVariant.MessageStream") } if len(msgs) == 1 { e.concatenatedMessage = msgs[0] } else { var err error e.concatenatedMessage, err = schema.ConcatMessages(msgs) if err != nil { e.StreamErr = err e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) return nil, err } } return e.concatenatedMessage, nil } // copyAgentEvent copies an AgentEvent. // If the MessageVariant is streaming, the MessageStream will be copied. // RunPath will be deep copied. // The result of Copy will be a new AgentEvent that is: // - safe to set fields of AgentEvent // - safe to extend RunPath // - safe to receive from MessageStream // NOTE: even if the AgentEvent is copied, it's still not recommended to modify // the Message itself or Chunks of the MessageStream, as they are not copied. // NOTE: if you have CustomizedOutput or CustomizedAction, they are NOT copied. 
func copyAgentEvent(ae *AgentEvent) *AgentEvent { rp := make([]RunStep, len(ae.RunPath)) copy(rp, ae.RunPath) copied := &AgentEvent{ AgentName: ae.AgentName, RunPath: rp, Action: ae.Action, Err: ae.Err, } if ae.Output == nil { return copied } copied.Output = &AgentOutput{ CustomizedOutput: ae.Output.CustomizedOutput, } mv := ae.Output.MessageOutput if mv == nil { return copied } copied.Output.MessageOutput = &MessageVariant{ IsStreaming: mv.IsStreaming, Role: mv.Role, ToolName: mv.ToolName, } if mv.IsStreaming { sts := ae.Output.MessageOutput.MessageStream.Copy(2) mv.MessageStream = sts[0] copied.Output.MessageOutput.MessageStream = sts[1] } else { copied.Output.MessageOutput.Message = mv.Message } return copied } // GetMessage extracts the Message from an AgentEvent. For streaming output, // it duplicates the stream and concatenates it into a single Message. func GetMessage(e *AgentEvent) (Message, *AgentEvent, error) { if e.Output == nil || e.Output.MessageOutput == nil { return nil, e, nil } msgOutput := e.Output.MessageOutput if msgOutput.IsStreaming { ss := msgOutput.MessageStream.Copy(2) e.Output.MessageOutput.MessageStream = ss[0] msg, err := schema.ConcatMessageStream(ss[1]) return msg, e, err } return msgOutput.Message, e, nil } func genErrorIter(err error) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{Err: err}) generator.Close() return iterator } ================================================ FILE: adk/utils_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "bytes" "encoding/gob" "errors" "fmt" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestAsyncIteratorPair_Basic(t *testing.T) { // Create a new iterator-generator pair iterator, generator := NewAsyncIteratorPair[string]() // Test sending and receiving a value generator.Send("test1") val, ok := iterator.Next() if !ok { t.Error("receive should succeed") } if val != "test1" { t.Errorf("expected 'test1', got '%s'", val) } // Test sending and receiving multiple values generator.Send("test2") generator.Send("test3") val, ok = iterator.Next() if !ok { t.Error("receive should succeed") } if val != "test2" { t.Errorf("expected 'test2', got '%s'", val) } val, ok = iterator.Next() if !ok { t.Error("receive should succeed") } if val != "test3" { t.Errorf("expected 'test3', got '%s'", val) } } func TestAsyncIteratorPair_Close(t *testing.T) { iterator, generator := NewAsyncIteratorPair[int]() // Send some values generator.Send(1) generator.Send(2) // Close the generator generator.Close() // Should still be able to read existing values val, ok := iterator.Next() if !ok { t.Error("receive should succeed") } if val != 1 { t.Errorf("expected 1, got %d", val) } val, ok = iterator.Next() if !ok { t.Error("receive should succeed") } if val != 2 { t.Errorf("expected 2, got %d", val) } // After consuming all values, Next should return false _, ok = iterator.Next() if ok { t.Error("receive from closed, empty channel should return ok=false") } } func TestAsyncIteratorPair_Concurrency(t *testing.T) { iterator, 
generator := NewAsyncIteratorPair[int]() const numSenders = 5 const numReceivers = 3 const messagesPerSender = 100 var rwg, swg sync.WaitGroup rwg.Add(numReceivers) swg.Add(numSenders) // Start senders for i := 0; i < numSenders; i++ { go func(id int) { defer swg.Done() for j := 0; j < messagesPerSender; j++ { generator.Send(id*messagesPerSender + j) time.Sleep(time.Microsecond) // Small delay to increase concurrency chance } }(i) } // Start receivers received := make([]int, 0, numSenders*messagesPerSender) var mu sync.Mutex for i := 0; i < numReceivers; i++ { go func() { defer rwg.Done() for { val, ok := iterator.Next() if !ok { return } mu.Lock() received = append(received, val) mu.Unlock() } }() } // Wait for senders to finish swg.Wait() generator.Close() // Wait for all goroutines to finish rwg.Wait() // Verify we received all messages if len(received) != numSenders*messagesPerSender { t.Errorf("expected %d messages, got %d", numSenders*messagesPerSender, len(received)) } // Create a map to check for duplicates and missing values receivedMap := make(map[int]bool) for _, val := range received { receivedMap[val] = true } if len(receivedMap) != numSenders*messagesPerSender { t.Error("duplicate or missing messages detected") } } func TestGenErrorIter(t *testing.T) { iter := genErrorIter(fmt.Errorf("test")) e, ok := iter.Next() assert.True(t, ok) assert.Equal(t, "test", e.Err.Error()) _, ok = iter.Next() assert.False(t, ok) } func TestGetMessageFromWrappedEvent_StreamError_MultipleCallsGuard(t *testing.T) { streamErr := errors.New("stream error") sr, sw := schema.Pipe[Message](10) go func() { defer sw.Close() sw.Send(schema.AssistantMessage("chunk1", nil), nil) sw.Send(schema.AssistantMessage("chunk2", nil), nil) sw.Send(nil, streamErr) }() wrapper := &agentEventWrapper{ AgentEvent: &AgentEvent{ Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, MessageStream: sr, }, }, }, } msg1, err1 := getMessageFromWrappedEvent(wrapper) assert.Nil(t, msg1) 
assert.NotNil(t, err1) assert.Equal(t, "stream error", err1.Error()) assert.NotEmpty(t, wrapper.StreamErr) assert.Equal(t, err1, wrapper.StreamErr) msg2, err2 := getMessageFromWrappedEvent(wrapper) assert.Nil(t, msg2) assert.Equal(t, err1, err2) } func TestGetMessageFromWrappedEvent_StreamSuccess_MultipleCallsCached(t *testing.T) { sr, sw := schema.Pipe[Message](10) go func() { defer sw.Close() sw.Send(schema.AssistantMessage("chunk1", nil), nil) sw.Send(schema.AssistantMessage("chunk2", nil), nil) }() wrapper := &agentEventWrapper{ AgentEvent: &AgentEvent{ Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, MessageStream: sr, }, }, }, } msg1, err1 := getMessageFromWrappedEvent(wrapper) assert.NotNil(t, msg1) assert.Nil(t, err1) assert.Equal(t, "chunk1chunk2", msg1.Content) assert.NotNil(t, wrapper.concatenatedMessage) msg2, err2 := getMessageFromWrappedEvent(wrapper) assert.NotNil(t, msg2) assert.Nil(t, err2) assert.Equal(t, "chunk1chunk2", msg2.Content) assert.Same(t, msg1, msg2) } func TestGetMessageFromWrappedEvent_StreamError_PartialMessagesPreserved(t *testing.T) { streamErr := errors.New("stream error at chunk3") sr, sw := schema.Pipe[Message](10) go func() { defer sw.Close() sw.Send(schema.AssistantMessage("chunk1", nil), nil) sw.Send(schema.AssistantMessage("chunk2", nil), nil) sw.Send(nil, streamErr) }() wrapper := &agentEventWrapper{ AgentEvent: &AgentEvent{ Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, MessageStream: sr, }, }, }, } _, err := getMessageFromWrappedEvent(wrapper) assert.NotNil(t, err) assert.Equal(t, streamErr, wrapper.StreamErr) newStream := wrapper.AgentEvent.Output.MessageOutput.MessageStream assert.NotNil(t, newStream) var msgs []Message for { msg, err := newStream.Recv() if err != nil { break } msgs = append(msgs, msg) } assert.Equal(t, 2, len(msgs)) assert.Equal(t, "chunk1", msgs[0].Content) assert.Equal(t, "chunk2", msgs[1].Content) } func 
TestAgentEventWrapper_GobEncoding_WithWillRetryError(t *testing.T) { streamErr := &WillRetryError{ErrStr: "stream error", RetryAttempt: 2} sr, sw := schema.Pipe[Message](10) go func() { defer sw.Close() sw.Send(schema.AssistantMessage("partial1", nil), nil) sw.Send(schema.AssistantMessage("partial2", nil), nil) sw.Send(nil, streamErr) }() wrapper := &agentEventWrapper{ AgentEvent: &AgentEvent{ AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, MessageStream: sr, }, }, }, TS: 12345, } _, err := getMessageFromWrappedEvent(wrapper) assert.NotNil(t, err) var wrapperErr *WillRetryError assert.True(t, errors.As(wrapper.StreamErr, &wrapperErr)) assert.Equal(t, streamErr.ErrStr, wrapperErr.ErrStr) assert.Equal(t, streamErr.RetryAttempt, wrapperErr.RetryAttempt) var buf bytes.Buffer enc := gob.NewEncoder(&buf) err = enc.Encode(wrapper) assert.NoError(t, err) var decoded agentEventWrapper dec := gob.NewDecoder(&buf) err = dec.Decode(&decoded) assert.NoError(t, err) assert.Equal(t, "TestAgent", decoded.AgentName) assert.Equal(t, int64(12345), decoded.TS) var decodedErr *WillRetryError assert.True(t, errors.As(decoded.StreamErr, &decodedErr)) assert.Equal(t, streamErr.ErrStr, decodedErr.ErrStr) assert.Equal(t, streamErr.RetryAttempt, decodedErr.RetryAttempt) } func TestAgentEventWrapper_GobEncoding_WithUnregisteredError(t *testing.T) { streamErr := errors.New("unregistered error type") sr, sw := schema.Pipe[Message](10) go func() { defer sw.Close() sw.Send(schema.AssistantMessage("partial1", nil), nil) sw.Send(nil, streamErr) }() wrapper := &agentEventWrapper{ AgentEvent: &AgentEvent{ AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, MessageStream: sr, }, }, }, TS: 22222, } _, err := getMessageFromWrappedEvent(wrapper) assert.NotNil(t, err) assert.Equal(t, streamErr, wrapper.StreamErr) var buf bytes.Buffer enc := gob.NewEncoder(&buf) err = enc.Encode(wrapper) assert.Error(t, err, "gob 
encoding should fail for unregistered error types") } func TestAgentEventWrapper_GobEncoding_WithStreamSuccess(t *testing.T) { sr, sw := schema.Pipe[Message](10) go func() { defer sw.Close() sw.Send(schema.AssistantMessage("success1", nil), nil) sw.Send(schema.AssistantMessage("success2", nil), nil) }() wrapper := &agentEventWrapper{ AgentEvent: &AgentEvent{ AgentName: "TestAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: true, MessageStream: sr, }, }, }, TS: 67890, } msg, err := getMessageFromWrappedEvent(wrapper) assert.NoError(t, err) assert.Equal(t, "success1success2", msg.Content) var buf bytes.Buffer enc := gob.NewEncoder(&buf) err = enc.Encode(wrapper) assert.NoError(t, err) var decoded agentEventWrapper dec := gob.NewDecoder(&buf) err = dec.Decode(&decoded) assert.NoError(t, err) assert.Equal(t, "TestAgent", decoded.AgentName) assert.Equal(t, int64(67890), decoded.TS) assert.Empty(t, decoded.StreamErr) } ================================================ FILE: adk/workflow.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adk import ( "context" "fmt" "runtime/debug" "sync" "github.com/cloudwego/eino/internal/core" "github.com/cloudwego/eino/internal/safe" "github.com/cloudwego/eino/schema" ) type workflowAgentMode int const ( workflowAgentModeUnknown workflowAgentMode = iota workflowAgentModeSequential workflowAgentModeLoop workflowAgentModeParallel ) type workflowAgent struct { name string description string subAgents []*flowAgent mode workflowAgentMode maxIterations int } func (a *workflowAgent) Name(_ context.Context) string { return a.name } func (a *workflowAgent) Description(_ context.Context) string { return a.description } func (a *workflowAgent) GetType() string { switch a.mode { case workflowAgentModeSequential: return "Sequential" case workflowAgentModeParallel: return "Parallel" case workflowAgentModeLoop: return "Loop" default: return "WorkflowAgent" } } func (a *workflowAgent) Run(ctx context.Context, _ *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go func() { var err error defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) generator.Send(&AgentEvent{Err: e}) } else if err != nil { generator.Send(&AgentEvent{Err: err}) } generator.Close() }() // Different workflow execution based on mode switch a.mode { case workflowAgentModeSequential: err = a.runSequential(ctx, generator, nil, nil, opts...) case workflowAgentModeLoop: err = a.runLoop(ctx, generator, nil, nil, opts...) case workflowAgentModeParallel: err = a.runParallel(ctx, generator, nil, nil, opts...) 
default: err = fmt.Errorf("unsupported workflow agent mode: %d", a.mode) } }() return iterator } type sequentialWorkflowState struct { InterruptIndex int } type parallelWorkflowState struct { SubAgentEvents map[int][]*agentEventWrapper } type loopWorkflowState struct { LoopIterations int SubAgentIndex int } func init() { schema.RegisterName[*sequentialWorkflowState]("eino_adk_sequential_workflow_state") schema.RegisterName[*parallelWorkflowState]("eino_adk_parallel_workflow_state") schema.RegisterName[*loopWorkflowState]("eino_adk_loop_workflow_state") } func (a *workflowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go func() { var err error defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) generator.Send(&AgentEvent{Err: e}) } else if err != nil { generator.Send(&AgentEvent{Err: err}) } generator.Close() }() state := info.InterruptState if state == nil { panic(fmt.Sprintf("workflowAgent.Resume: agent '%s' was asked to resume but has no state", a.Name(ctx))) } // Different workflow execution based on the type of our restored state. switch s := state.(type) { case *sequentialWorkflowState: err = a.runSequential(ctx, generator, s, info, opts...) case *parallelWorkflowState: err = a.runParallel(ctx, generator, s, info, opts...) case *loopWorkflowState: err = a.runLoop(ctx, generator, s, info, opts...) default: err = fmt.Errorf("unsupported workflow agent state type: %T", s) } }() return iterator } // WorkflowInterruptInfo CheckpointSchema: persisted via InterruptInfo.Data (gob). 
type WorkflowInterruptInfo struct { OrigInput *AgentInput SequentialInterruptIndex int SequentialInterruptInfo *InterruptInfo LoopIterations int ParallelInterruptInfo map[int] /*index*/ *InterruptInfo } func (a *workflowAgent) runSequential(ctx context.Context, generator *AsyncGenerator[*AgentEvent], seqState *sequentialWorkflowState, info *ResumeInfo, opts ...AgentRunOption) (err error) { startIdx := 0 // seqCtx tracks the accumulated RunPath across the sequence. seqCtx := ctx // If we are resuming, find which sub-agent to start from and prepare its context. if seqState != nil { startIdx = seqState.InterruptIndex var steps []string for i := 0; i < startIdx; i++ { steps = append(steps, a.subAgents[i].Name(seqCtx)) } seqCtx = updateRunPathOnly(seqCtx, steps...) } for i := startIdx; i < len(a.subAgents); i++ { subAgent := a.subAgents[i] var subIterator *AsyncIterator[*AgentEvent] if seqState != nil { subIterator = subAgent.Resume(seqCtx, &ResumeInfo{ EnableStreaming: info.EnableStreaming, InterruptInfo: info.Data.(*WorkflowInterruptInfo).SequentialInterruptInfo, }, opts...) seqState = nil } else { subIterator = subAgent.Run(seqCtx, nil, opts...) } seqCtx = updateRunPathOnly(seqCtx, subAgent.Name(seqCtx)) var lastActionEvent *AgentEvent for { event, ok := subIterator.Next() if !ok { break } if event.Err != nil { // exit if report error generator.Send(event) return nil } if lastActionEvent != nil { generator.Send(lastActionEvent) lastActionEvent = nil } if event.Action != nil { lastActionEvent = event continue } generator.Send(event) } if lastActionEvent != nil { if lastActionEvent.Action.internalInterrupted != nil { // A sub-agent interrupted. Wrap it with our own state, including the index. state := &sequentialWorkflowState{ InterruptIndex: i, } // Use CompositeInterrupt to funnel the sub-interrupt and add our own state. // The context for the composite interrupt must be the one from *before* the sub-agent ran. 
event := CompositeInterrupt(ctx, "Sequential workflow interrupted", state, lastActionEvent.Action.internalInterrupted) // For backward compatibility, populate the deprecated Data field. event.Action.Interrupted.Data = &WorkflowInterruptInfo{ OrigInput: getRunCtx(ctx).RootInput, SequentialInterruptIndex: i, SequentialInterruptInfo: lastActionEvent.Action.Interrupted, } event.AgentName = lastActionEvent.AgentName event.RunPath = lastActionEvent.RunPath generator.Send(event) return nil } if lastActionEvent.Action.Exit { // Forward the event generator.Send(lastActionEvent) return nil } generator.Send(lastActionEvent) } } return nil } // BreakLoopAction is a programmatic-only agent action used to prematurely // terminate the execution of a loop workflow agent. // When a loop workflow agent receives this action from a sub-agent, it will stop its // current iteration and will not proceed to the next one. // It will mark the BreakLoopAction as Done, signalling to any 'upper level' loop agent // that this action has been processed and should be ignored further up. // This action is not intended to be used by LLMs. type BreakLoopAction struct { // From records the name of the agent that initiated the break loop action. From string // Done is a state flag that can be used by the framework to mark when the // action has been handled. Done bool // CurrentIterations is populated by the framework to record at which // iteration the loop was broken. CurrentIterations int } // NewBreakLoopAction creates a new BreakLoopAction, signaling a request // to terminate the current loop. 
func NewBreakLoopAction(agentName string) *AgentAction { return &AgentAction{BreakLoop: &BreakLoopAction{ From: agentName, }} } func (a *workflowAgent) runLoop(ctx context.Context, generator *AsyncGenerator[*AgentEvent], loopState *loopWorkflowState, resumeInfo *ResumeInfo, opts ...AgentRunOption) (err error) { if len(a.subAgents) == 0 { return nil } startIter := 0 startIdx := 0 // loopCtx tracks the accumulated RunPath across the full sequence within a single iteration. loopCtx := ctx if loopState != nil { // We are resuming. startIter = loopState.LoopIterations startIdx = loopState.SubAgentIndex // Rebuild the loopCtx to have the correct RunPath up to the point of resumption. var steps []string for i := 0; i < startIter; i++ { for _, subAgent := range a.subAgents { steps = append(steps, subAgent.Name(loopCtx)) } } for i := 0; i < startIdx; i++ { steps = append(steps, a.subAgents[i].Name(loopCtx)) } loopCtx = updateRunPathOnly(loopCtx, steps...) } for i := startIter; i < a.maxIterations || a.maxIterations == 0; i++ { for j := startIdx; j < len(a.subAgents); j++ { subAgent := a.subAgents[j] var subIterator *AsyncIterator[*AgentEvent] if loopState != nil { // This is the agent we need to resume. subIterator = subAgent.Resume(loopCtx, &ResumeInfo{ EnableStreaming: resumeInfo.EnableStreaming, InterruptInfo: resumeInfo.Data.(*WorkflowInterruptInfo).SequentialInterruptInfo, }, opts...) loopState = nil // Only resume the first time. } else { subIterator = subAgent.Run(loopCtx, nil, opts...) 
} loopCtx = updateRunPathOnly(loopCtx, subAgent.Name(loopCtx)) var lastActionEvent *AgentEvent var breakLoopEvent *AgentEvent for { event, ok := subIterator.Next() if !ok { break } if event.Err != nil { generator.Send(event) return nil } if lastActionEvent != nil { if lastActionEvent.Action.BreakLoop != nil && !lastActionEvent.Action.BreakLoop.Done { lastActionEvent.Action.BreakLoop.Done = true lastActionEvent.Action.BreakLoop.CurrentIterations = i breakLoopEvent = lastActionEvent } generator.Send(lastActionEvent) lastActionEvent = nil } if event.Action != nil { lastActionEvent = event continue } generator.Send(event) } if lastActionEvent != nil { if lastActionEvent.Action.BreakLoop != nil && !lastActionEvent.Action.BreakLoop.Done { lastActionEvent.Action.BreakLoop.Done = true lastActionEvent.Action.BreakLoop.CurrentIterations = i breakLoopEvent = lastActionEvent } if lastActionEvent.Action.internalInterrupted != nil { // A sub-agent interrupted. Wrap it with our own loop state. state := &loopWorkflowState{ LoopIterations: i, SubAgentIndex: j, } // Use CompositeInterrupt to funnel the sub-interrupt and add our own state. event := CompositeInterrupt(ctx, "Loop workflow interrupted", state, lastActionEvent.Action.internalInterrupted) // For backward compatibility, populate the deprecated Data field. event.Action.Interrupted.Data = &WorkflowInterruptInfo{ OrigInput: getRunCtx(ctx).RootInput, LoopIterations: i, SequentialInterruptIndex: j, SequentialInterruptInfo: lastActionEvent.Action.Interrupted, } event.AgentName = lastActionEvent.AgentName event.RunPath = lastActionEvent.RunPath generator.Send(event) return } if lastActionEvent.Action.Exit { generator.Send(lastActionEvent) return } generator.Send(lastActionEvent) } if breakLoopEvent != nil { return } } // Reset the sub-agent index for the next iteration of the outer loop. 
startIdx = 0 } return nil } func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerator[*AgentEvent], parState *parallelWorkflowState, resumeInfo *ResumeInfo, opts ...AgentRunOption) error { if len(a.subAgents) == 0 { return nil } var ( wg sync.WaitGroup subInterruptSignals []*core.InterruptSignal dataMap = make(map[int]*InterruptInfo) mu sync.Mutex agentNames map[string]bool err error childContexts = make([]context.Context, len(a.subAgents)) ) // If resuming, get the scoped ResumeInfo for each child that needs to be resumed. if parState != nil { agentNames, err = getNextResumeAgents(ctx, resumeInfo) if err != nil { return err } } // Fork contexts for each sub-agent for i := range a.subAgents { childContexts[i] = forkRunCtx(ctx) // If we're resuming and this agent has existing events, add them to the child context if parState != nil && parState.SubAgentEvents != nil { if existingEvents, ok := parState.SubAgentEvents[i]; ok { // Add existing events to the child's lane events childRunCtx := getRunCtx(childContexts[i]) if childRunCtx != nil && childRunCtx.Session != nil { if childRunCtx.Session.LaneEvents == nil { childRunCtx.Session.LaneEvents = &laneEvents{} } childRunCtx.Session.LaneEvents.Events = append(childRunCtx.Session.LaneEvents.Events, existingEvents...) } } } } for i := range a.subAgents { wg.Add(1) go func(idx int, agent *flowAgent) { defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) generator.Send(&AgentEvent{Err: e}) } wg.Done() }() var iterator *AsyncIterator[*AgentEvent] if _, ok := agentNames[agent.Name(ctx)]; ok { // This branch was interrupted and needs to be resumed. iterator = agent.Resume(childContexts[idx], &ResumeInfo{ EnableStreaming: resumeInfo.EnableStreaming, InterruptInfo: resumeInfo.Data.(*WorkflowInterruptInfo).ParallelInterruptInfo[idx], }, opts...) } else if parState != nil { // We are resuming, but this child is not in the next points map. 
// This means it finished successfully, so we don't run it. return } else { iterator = agent.Run(childContexts[idx], nil, opts...) } for { event, ok := iterator.Next() if !ok { break } if event.Action != nil && event.Action.internalInterrupted != nil { mu.Lock() subInterruptSignals = append(subInterruptSignals, event.Action.internalInterrupted) dataMap[idx] = event.Action.Interrupted mu.Unlock() break } generator.Send(event) } }(i, a.subAgents[i]) } wg.Wait() if len(subInterruptSignals) == 0 { // Join all child contexts back to the parent joinRunCtxs(ctx, childContexts...) return nil } if len(subInterruptSignals) > 0 { // Before interrupting, collect the current events from each child context subAgentEvents := make(map[int][]*agentEventWrapper) for i, childCtx := range childContexts { childRunCtx := getRunCtx(childCtx) if childRunCtx != nil && childRunCtx.Session != nil && childRunCtx.Session.LaneEvents != nil { subAgentEvents[i] = childRunCtx.Session.LaneEvents.Events } } state := ¶llelWorkflowState{ SubAgentEvents: subAgentEvents, } event := CompositeInterrupt(ctx, "Parallel workflow interrupted", state, subInterruptSignals...) // For backward compatibility, populate the deprecated Data field. 
event.Action.Interrupted.Data = &WorkflowInterruptInfo{ OrigInput: getRunCtx(ctx).RootInput, ParallelInterruptInfo: dataMap, } event.AgentName = a.Name(ctx) event.RunPath = getRunCtx(ctx).RunPath generator.Send(event) } return nil } type SequentialAgentConfig struct { Name string Description string SubAgents []Agent } type ParallelAgentConfig struct { Name string Description string SubAgents []Agent } type LoopAgentConfig struct { Name string Description string SubAgents []Agent MaxIterations int } func newWorkflowAgent(ctx context.Context, name, desc string, subAgents []Agent, mode workflowAgentMode, maxIterations int) (*flowAgent, error) { wa := &workflowAgent{ name: name, description: desc, mode: mode, maxIterations: maxIterations, } fas := make([]Agent, len(subAgents)) for i, subAgent := range subAgents { fas[i] = toFlowAgent(ctx, subAgent, WithDisallowTransferToParent()) } fa, err := setSubAgents(ctx, wa, fas) if err != nil { return nil, err } wa.subAgents = fa.subAgents return fa, nil } // NewSequentialAgent creates an agent that runs sub-agents sequentially. func NewSequentialAgent(ctx context.Context, config *SequentialAgentConfig) (ResumableAgent, error) { return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeSequential, 0) } // NewParallelAgent creates an agent that runs sub-agents in parallel. func NewParallelAgent(ctx context.Context, config *ParallelAgentConfig) (ResumableAgent, error) { return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeParallel, 0) } // NewLoopAgent creates an agent that loops over sub-agents with a max iteration limit. 
func NewLoopAgent(ctx context.Context, config *LoopAgentConfig) (ResumableAgent, error) { return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeLoop, config.MaxIterations) } ================================================ FILE: adk/workflow_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "fmt" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) // mockAgent is a simple implementation of the Agent interface for testing type mockAgent struct { name string description string responses []*AgentEvent } func (a *mockAgent) Name(_ context.Context) string { return a.name } func (a *mockAgent) Description(_ context.Context) string { return a.description } func (a *mockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go func() { defer generator.Close() for _, event := range a.responses { generator.Send(event) // If the event has an Exit action, stop sending events if event.Action != nil && event.Action.Exit { break } } }() return iterator } // newMockAgent creates a new mock agent with the given name, description, and responses func newMockAgent(name, description string, responses []*AgentEvent) *mockAgent { return &mockAgent{ name: name, 
description: description, responses: responses, } } // TestSequentialAgent tests the sequential workflow agent func TestSequentialAgent(t *testing.T) { ctx := context.Background() // Create mock agents with predefined responses agent1 := newMockAgent("Agent1", "First agent", []*AgentEvent{ { AgentName: "Agent1", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Response from Agent1", nil), Role: schema.Assistant, }, }, }, }) agent2 := newMockAgent("Agent2", "Second agent", []*AgentEvent{ { AgentName: "Agent2", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Response from Agent2", nil), Role: schema.Assistant, }, }}, }) // Create a sequential agent with the mock agents config := &SequentialAgentConfig{ Name: "SequentialTestAgent", Description: "Test sequential agent", SubAgents: []Agent{agent1, agent2}, } sequentialAgent, err := NewSequentialAgent(ctx, config) assert.NoError(t, err) assert.NotNil(t, sequentialAgent) assert.Equal(t, "Test sequential agent", sequentialAgent.Description(ctx)) // Run the sequential agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Test input"), }, } // Initialize the run context ctx, _ = initRunCtx(ctx, sequentialAgent.Name(ctx), input) iterator := sequentialAgent.Run(ctx, input) assert.NotNil(t, iterator) // First event should be from agent1 event1, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event1) assert.Nil(t, event1.Err) assert.NotNil(t, event1.Output) assert.NotNil(t, event1.Output.MessageOutput) // Get the message content from agent1 msg1 := event1.Output.MessageOutput.Message assert.NotNil(t, msg1) assert.Equal(t, "Response from Agent1", msg1.Content) // Second event should be from agent2 event2, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event2) assert.Nil(t, event2.Err) assert.NotNil(t, event2.Output) assert.NotNil(t, event2.Output.MessageOutput) // Get the 
message content from agent2 msg2 := event2.Output.MessageOutput.Message assert.NotNil(t, msg2) assert.Equal(t, "Response from Agent2", msg2.Content) // No more events _, ok = iterator.Next() assert.False(t, ok) } // TestSequentialAgentWithExit tests the sequential workflow agent with an exit action func TestSequentialAgentWithExit(t *testing.T) { ctx := context.Background() // Create mock agents with predefined responses agent1 := newMockAgent("Agent1", "First agent", []*AgentEvent{ { AgentName: "Agent1", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Response from Agent1", nil), Role: schema.Assistant, }, }, Action: &AgentAction{ Exit: true, }, }, }) agent2 := newMockAgent("Agent2", "Second agent", []*AgentEvent{ { AgentName: "Agent2", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Response from Agent2", nil), Role: schema.Assistant, }, }, }, }) // Create a sequential agent with the mock agents config := &SequentialAgentConfig{ Name: "SequentialTestAgent", Description: "Test sequential agent", SubAgents: []Agent{agent1, agent2}, } sequentialAgent, err := NewSequentialAgent(ctx, config) assert.NoError(t, err) assert.NotNil(t, sequentialAgent) // Run the sequential agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Test input"), }, } ctx, _ = initRunCtx(ctx, sequentialAgent.Name(ctx), input) iterator := sequentialAgent.Run(ctx, input) assert.NotNil(t, iterator) // First event should be from agent1 with exit action event1, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event1) assert.Nil(t, event1.Err) assert.NotNil(t, event1.Output) assert.NotNil(t, event1.Output.MessageOutput) assert.NotNil(t, event1.Action) assert.True(t, event1.Action.Exit) // No more events due to exit action _, ok = iterator.Next() assert.False(t, ok) } // TestParallelAgent tests the parallel workflow agent func TestParallelAgent(t *testing.T) { 
ctx := context.Background() // Create mock agents with predefined responses agent1 := newMockAgent("Agent1", "First agent", []*AgentEvent{ { AgentName: "Agent1", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Response from Agent1", nil), Role: schema.Assistant, }, }, }, }) agent2 := newMockAgent("Agent2", "Second agent", []*AgentEvent{ { AgentName: "Agent2", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Response from Agent2", nil), Role: schema.Assistant, }, }, }, }) // Create a parallel agent with the mock agents config := &ParallelAgentConfig{ Name: "ParallelTestAgent", Description: "Test parallel agent", SubAgents: []Agent{agent1, agent2}, } parallelAgent, err := NewParallelAgent(ctx, config) assert.NoError(t, err) assert.NotNil(t, parallelAgent) // Run the parallel agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Test input"), }, } ctx, _ = initRunCtx(ctx, parallelAgent.Name(ctx), input) iterator := parallelAgent.Run(ctx, input) assert.NotNil(t, iterator) // Collect all events var events []*AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } // Should have two events, one from each agent assert.Equal(t, 2, len(events)) // Verify the events for _, event := range events { assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) msg := event.Output.MessageOutput.Message assert.NotNil(t, msg) assert.NoError(t, err) // Check the source agent name and message content if event.AgentName == "Agent1" { assert.Equal(t, "Response from Agent1", msg.Content) } else if event.AgentName == "Agent2" { assert.Equal(t, "Response from Agent2", msg.Content) } else { t.Fatalf("Unexpected source agent name: %s", event.AgentName) } } } // TestLoopAgent tests the loop workflow agent func TestLoopAgent(t *testing.T) { ctx := context.Background() // Create a 
mock agent that will be called multiple times agent := newMockAgent("LoopAgent", "Loop agent", []*AgentEvent{ { AgentName: "LoopAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Loop iteration", nil), Role: schema.Assistant, }, }, }, }) // Create a loop agent with the mock agent and max iterations set to 3 config := &LoopAgentConfig{ Name: "LoopTestAgent", Description: "Test loop agent", SubAgents: []Agent{agent}, MaxIterations: 3, } loopAgent, err := NewLoopAgent(ctx, config) assert.NoError(t, err) assert.NotNil(t, loopAgent) // Run the loop agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Test input"), }, } ctx, _ = initRunCtx(ctx, loopAgent.Name(ctx), input) iterator := loopAgent.Run(ctx, input) assert.NotNil(t, iterator) // Collect all events var events []*AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } // Should have 3 events (one for each iteration) assert.Equal(t, 3, len(events)) // Verify all events for _, event := range events { assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) msg := event.Output.MessageOutput.Message assert.NotNil(t, msg) assert.Equal(t, "Loop iteration", msg.Content) } } // TestLoopAgentWithBreakLoop tests the loop workflow agent with an break loop action func TestLoopAgentWithBreakLoop(t *testing.T) { ctx := context.Background() // Create a mock agent that will break the loop after the first iteration agent := newMockAgent("LoopAgent", "Loop agent", []*AgentEvent{ { AgentName: "LoopAgent", Output: &AgentOutput{ MessageOutput: &MessageVariant{ IsStreaming: false, Message: schema.AssistantMessage("Loop iteration with break loop", nil), Role: schema.Assistant, }, }, Action: NewBreakLoopAction("LoopAgent"), }, }) // Create a loop agent with the mock agent and max iterations set to 3 config := &LoopAgentConfig{ Name: "LoopTestAgent", Description: "Test loop 
agent", SubAgents: []Agent{agent}, MaxIterations: 3, } loopAgent, err := NewLoopAgent(ctx, config) assert.NoError(t, err) assert.NotNil(t, loopAgent) // Run the loop agent input := &AgentInput{ Messages: []Message{ schema.UserMessage("Test input"), }, } ctx, _ = initRunCtx(ctx, loopAgent.Name(ctx), input) iterator := loopAgent.Run(ctx, input) assert.NotNil(t, iterator) // Collect all events var events []*AgentEvent for { event, ok := iterator.Next() if !ok { break } events = append(events, event) } // Should have only 1 event due to break loop action assert.Equal(t, 1, len(events)) // Verify the event event := events[0] assert.Nil(t, event.Err) assert.NotNil(t, event.Output) assert.NotNil(t, event.Output.MessageOutput) assert.NotNil(t, event.Action) assert.NotNil(t, event.Action.BreakLoop) assert.True(t, event.Action.BreakLoop.Done) assert.Equal(t, "LoopAgent", event.Action.BreakLoop.From) assert.Equal(t, 0, event.Action.BreakLoop.CurrentIterations) msg := event.Output.MessageOutput.Message assert.NotNil(t, msg) assert.Equal(t, "Loop iteration with break loop", msg.Content) } // Add these test functions to the existing workflow_test.go file // Replace the existing TestWorkflowAgentPanicRecovery function func TestWorkflowAgentPanicRecovery(t *testing.T) { ctx := context.Background() // Create a panic agent that panics in Run method panicAgent := &panicMockAgent{ mockAgent: mockAgent{ name: "PanicAgent", description: "Agent that panics", responses: []*AgentEvent{}, }, } // Create a sequential agent with the panic agent config := &SequentialAgentConfig{ Name: "PanicTestAgent", Description: "Test agent with panic", SubAgents: []Agent{panicAgent}, } sequentialAgent, err := NewSequentialAgent(ctx, config) assert.NoError(t, err) // Run the agent and expect panic recovery input := &AgentInput{ Messages: []Message{ schema.UserMessage("Test input"), }, } ctx, _ = initRunCtx(ctx, sequentialAgent.Name(ctx), input) iterator := sequentialAgent.Run(ctx, input) assert.NotNil(t, 
iterator) // Should receive an error event due to panic recovery event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.NotNil(t, event.Err) assert.Contains(t, event.Err.Error(), "panic") // No more events _, ok = iterator.Next() assert.False(t, ok) } // Add these new mock agent types that properly panic type panicMockAgent struct { mockAgent } func (a *panicMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { panic("test panic in agent") } func TestParallelWorkflowResumeWithEvents(t *testing.T) { ctx := context.Background() // Create interruptible agents sa1 := &myAgent{ name: "sa1", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() // Send a normal message event first, called event1 generator.Send(&AgentEvent{ AgentName: "sa1", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa1 normal message"), }, }, }) intEvent := Interrupt(ctx, "sa1 interrupt data") generator.Send(intEvent) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.True(t, info.WasInterrupted) assert.Nil(t, info.InterruptState) assert.True(t, info.IsResumeTarget) assert.Equal(t, "resume sa1", info.ResumeData) // Get the events from session and verify visibility runCtx := getRunCtx(ctx) assert.NotNil(t, runCtx.Session, "sa1 resumer should have session") allEvents := runCtx.Session.getEvents() // Assert that allEvents only have 1 event, that is event1 assert.Equal(t, 1, len(allEvents), "sa1 should only see its own event in session") assert.Equal(t, "sa1", allEvents[0].AgentEvent.AgentName, "sa1 should see its own event") assert.Equal(t, "sa1 normal message", allEvents[0].AgentEvent.Output.MessageOutput.Message.Content, "sa1 should see its own message content") iter, generator := 
NewAsyncIteratorPair[*AgentEvent]() generator.Close() return iter }, } sa2 := &myAgent{ name: "sa2", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() // Send a normal message event first, called event2 generator.Send(&AgentEvent{ AgentName: "sa2", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa2 normal message"), }, }, }) intEvent := StatefulInterrupt(ctx, "sa2 interrupt data", "sa2 interrupt") generator.Send(intEvent) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.True(t, info.WasInterrupted) assert.NotNil(t, info.InterruptState) assert.Equal(t, "sa2 interrupt", info.InterruptState) assert.True(t, info.IsResumeTarget) assert.Equal(t, "resume sa2", info.ResumeData) // Get the events from session and verify visibility runCtx := getRunCtx(ctx) assert.NotNil(t, runCtx.Session, "sa2 resumer should have session") allEvents := runCtx.Session.getEvents() // Assert that allEvents only have 1 event, that is event2 assert.Equal(t, 1, len(allEvents), "sa2 should only see its own event in session") assert.Equal(t, "sa2", allEvents[0].AgentEvent.AgentName, "sa2 should see its own event") assert.Equal(t, "sa2 normal message", allEvents[0].AgentEvent.Output.MessageOutput.Message.Content, "sa2 should see its own message content") iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Close() return iter }, } sa3 := &myAgent{ name: "sa3", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "sa3", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa3 completed"), }, }, }) generator.Close() return iter }, } sa4 := &myAgent{ name: "sa4", runFn: 
func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "sa4", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("sa4 completed"), }, }, }) generator.Close() return iter }, } t.Run("test parallel workflow agent", func(t *testing.T) { // parallel a, err := NewParallelAgent(ctx, &ParallelAgentConfig{ Name: "parallel agent", SubAgents: []Agent{sa1, sa2, sa3, sa4}, }) assert.NoError(t, err) runner := NewRunner(ctx, RunnerConfig{ Agent: a, CheckPointStore: newMyStore(), }) iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) var ( events []*AgentEvent interruptEvent *AgentEvent ) for { event, ok := iter.Next() if !ok { break } if event.Action != nil && event.Action.Interrupted != nil { interruptEvent = event continue } events = append(events, event) } assert.Equal(t, 4, len(events), "should have 4 events (2 normal messages + 2 completed agents)") // Verify specific properties of each event var sa3Event, sa4Event *AgentEvent for _, event := range events { if event.AgentName == "sa3" { sa3Event = event } else if event.AgentName == "sa4" { sa4Event = event } } // Verify sa3 event properties assert.NotNil(t, sa3Event, "should have event from sa3") assert.Equal(t, "sa3", sa3Event.AgentName, "sa3 event should have correct agent name") assert.Equal(t, []RunStep{{"parallel agent"}, {"sa3"}}, sa3Event.RunPath, "sa3 event should have correct run path") assert.NotNil(t, sa3Event.Output, "sa3 event should have output") assert.NotNil(t, sa3Event.Output.MessageOutput, "sa3 event should have message output") assert.Equal(t, "sa3 completed", sa3Event.Output.MessageOutput.Message.Content, "sa3 event should have correct message content") // Verify sa4 event properties assert.NotNil(t, sa4Event, "should have event from sa4") assert.Equal(t, "sa4", sa4Event.AgentName, "sa4 event should have correct agent 
name") assert.Equal(t, []RunStep{{"parallel agent"}, {"sa4"}}, sa4Event.RunPath, "sa4 event should have correct run path") assert.NotNil(t, sa4Event.Output, "sa4 event should have output") assert.NotNil(t, sa4Event.Output.MessageOutput, "sa4 event should have message output") assert.Equal(t, "sa4 completed", sa4Event.Output.MessageOutput.Message.Content, "sa4 event should have correct message content") assert.NotNil(t, interruptEvent) assert.Equal(t, "parallel agent", interruptEvent.AgentName) assert.Equal(t, []RunStep{{"parallel agent"}}, interruptEvent.RunPath) assert.NotNil(t, interruptEvent.Action.Interrupted) var sa1InfoFound, sa2InfoFound bool for _, ctx := range interruptEvent.Action.Interrupted.InterruptContexts { if ctx.Info == "sa1 interrupt data" { sa1InfoFound = true } else if ctx.Info == "sa2 interrupt data" { sa2InfoFound = true } } assert.Equal(t, 2, len(interruptEvent.Action.Interrupted.InterruptContexts)) assert.True(t, sa1InfoFound) assert.True(t, sa2InfoFound) var parallelInterruptID1, parallelInterruptID2 string for _, ctx := range interruptEvent.Action.Interrupted.InterruptContexts { if ctx.Info == "sa1 interrupt data" { parallelInterruptID1 = ctx.ID } else if ctx.Info == "sa2 interrupt data" { parallelInterruptID2 = ctx.ID } } assert.NotEmpty(t, parallelInterruptID1) assert.NotEmpty(t, parallelInterruptID2) iter, err = runner.ResumeWithParams(ctx, "1", &ResumeParams{ Targets: map[string]any{ parallelInterruptID1: "resume sa1", parallelInterruptID2: "resume sa2", }, }) assert.NoError(t, err) _, ok := iter.Next() assert.False(t, ok) }) } func TestNestedParallelWorkflow(t *testing.T) { ctx := context.Background() // Create predecessor agent that runs before the parallel structure predecessorAgent := &myAgent{ name: "predecessor", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "predecessor", 
Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("predecessor completed"), }, }, }) generator.Close() return iter }, } // Create interruptible inner agents innerAgent1 := &myAgent{ name: "inner1", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() // Verify inner1 can see predecessor's event runCtx := getRunCtx(ctx) allEvents := runCtx.Session.getEvents() assert.Equal(t, 1, len(allEvents), "inner1 should see exactly 1 event (predecessor)") assert.Equal(t, "predecessor", allEvents[0].AgentEvent.AgentName, "inner1 should see predecessor event") assert.Equal(t, "predecessor completed", allEvents[0].AgentEvent.Output.MessageOutput.Message.Content, "inner1 should see predecessor message content") generator.Send(&AgentEvent{ AgentName: "inner1", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("inner1 normal"), }, }, }) intEvent := Interrupt(ctx, "inner1 interrupt") generator.Send(intEvent) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.True(t, info.WasInterrupted) assert.Equal(t, "resume inner1", info.ResumeData) // Verify inner1 can see predecessor's event during resume runCtx := getRunCtx(ctx) allEvents := runCtx.Session.getEvents() assert.Equal(t, 2, len(allEvents), "inner1 should see exactly 2 events (predecessor + own normal message) during resume") // Find and verify predecessor event var foundPredecessor bool for _, event := range allEvents { if event.AgentEvent != nil && event.AgentEvent.AgentName == "predecessor" { foundPredecessor = true assert.Equal(t, "predecessor completed", event.AgentEvent.Output.MessageOutput.Message.Content) } } assert.True(t, foundPredecessor, "inner1 should see predecessor event during resume") iter, generator := NewAsyncIteratorPair[*AgentEvent]() 
generator.Close() return iter }, } innerAgent2 := &myAgent{ name: "inner2", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() // Verify inner2 can see predecessor's event runCtx := getRunCtx(ctx) allEvents := runCtx.Session.getEvents() assert.Equal(t, 1, len(allEvents), "inner2 should see exactly 1 event (predecessor)") assert.Equal(t, "predecessor", allEvents[0].AgentEvent.AgentName, "inner2 should see predecessor event") assert.Equal(t, "predecessor completed", allEvents[0].AgentEvent.Output.MessageOutput.Message.Content, "inner2 should see predecessor message content") generator.Send(&AgentEvent{ AgentName: "inner2", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("inner2 normal"), }, }, }) intEvent := StatefulInterrupt(ctx, "inner2 interrupt", "inner2 state") generator.Send(intEvent) generator.Close() return iter }, resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { assert.True(t, info.WasInterrupted) assert.Equal(t, "inner2 state", info.InterruptState) assert.Equal(t, "resume inner2", info.ResumeData) // Verify inner2 can see predecessor's event during resume runCtx := getRunCtx(ctx) allEvents := runCtx.Session.getEvents() assert.Equal(t, 2, len(allEvents), "inner2 should see exactly 2 events (predecessor + own normal message) during resume") // Find and verify predecessor event var foundPredecessor bool for _, event := range allEvents { if event.AgentEvent != nil && event.AgentEvent.AgentName == "predecessor" { foundPredecessor = true assert.Equal(t, "predecessor completed", event.AgentEvent.Output.MessageOutput.Message.Content) } } assert.True(t, foundPredecessor, "inner2 should see predecessor event during resume") iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Close() return iter }, } // Create inner parallel workflow innerParallel, err 
:= NewParallelAgent(ctx, &ParallelAgentConfig{ Name: "inner parallel", SubAgents: []Agent{innerAgent1, innerAgent2}, }) assert.NoError(t, err) // Create simple outer agents outerAgent1 := &myAgent{ name: "outer1", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "outer1", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("outer1 completed"), }, }, }) generator.Close() return iter }, } outerAgent2 := &myAgent{ name: "outer2", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() generator.Send(&AgentEvent{ AgentName: "outer2", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("outer2 completed"), }, }, }) generator.Close() return iter }, } // Create outer parallel workflow with nested parallel agent outerParallel, err := NewParallelAgent(ctx, &ParallelAgentConfig{ Name: "outer parallel", SubAgents: []Agent{outerAgent1, innerParallel, outerAgent2}, }) assert.NoError(t, err) // Create successor agent that runs after the parallel structure successorAgent := &myAgent{ name: "successor", runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, generator := NewAsyncIteratorPair[*AgentEvent]() // Verify successor can see all events from predecessor and parallel agents runCtx := getRunCtx(ctx) allEvents := runCtx.Session.getEvents() assert.GreaterOrEqual(t, len(allEvents), 5, "successor should see all events") var foundPredecessor, foundOuter1, foundOuter2, foundInner1, foundInner2 bool for _, event := range allEvents { if event.AgentEvent != nil { switch event.AgentEvent.AgentName { case "predecessor": foundPredecessor = true assert.Equal(t, "predecessor completed", 
event.AgentEvent.Output.MessageOutput.Message.Content) case "outer1": foundOuter1 = true case "outer2": foundOuter2 = true case "inner1": foundInner1 = true case "inner2": foundInner2 = true } } } assert.True(t, foundPredecessor, "successor should see predecessor event") assert.True(t, foundOuter1, "successor should see outer1 event") assert.True(t, foundOuter2, "successor should see outer2 event") assert.True(t, foundInner1, "successor should see inner1 event") assert.True(t, foundInner2, "successor should see inner2 event") generator.Send(&AgentEvent{ AgentName: "successor", Output: &AgentOutput{ MessageOutput: &MessageVariant{ Message: schema.UserMessage("successor completed"), }, }, }) generator.Close() return iter }, } // Create sequential workflow: predecessor -> parallel -> successor sequentialWorkflow, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ Name: "sequential workflow", SubAgents: []Agent{predecessorAgent, outerParallel, successorAgent}, }) assert.NoError(t, err) runner := NewRunner(ctx, RunnerConfig{ Agent: sequentialWorkflow, CheckPointStore: newMyStore(), }) iter := runner.Query(ctx, "test nested parallel with predecessor and successor", WithCheckPointID("nested-parallel-test")) var events []*AgentEvent var interruptEvent *AgentEvent for event, ok := iter.Next(); ok; event, ok = iter.Next() { if event.Action != nil && event.Action.Interrupted != nil { interruptEvent = event continue } events = append(events, event) } // Should get events from predecessor, outer agents, and inner normal messages (successor doesn't run due to interruption) assert.Equal(t, 5, len(events), "should have 5 events (predecessor + 2 outer + 2 inner)") if interruptEvent == nil { t.Fatal("should have interrupt event") } // Resume the inner parallel workflow var innerInterruptID1, innerInterruptID2 string for _, ctx := range interruptEvent.Action.Interrupted.InterruptContexts { if ctx.Info == "inner1 interrupt" { innerInterruptID1 = ctx.ID } else if ctx.Info == 
"inner2 interrupt" { innerInterruptID2 = ctx.ID } } iter, err = runner.ResumeWithParams(ctx, "nested-parallel-test", &ResumeParams{ Targets: map[string]any{ innerInterruptID1: "resume inner1", innerInterruptID2: "resume inner2", }, }) assert.NoError(t, err) // Verify resume completes successfully and successor runs var resumeEvents []*AgentEvent for event, ok := iter.Next(); ok; event, ok = iter.Next() { resumeEvents = append(resumeEvents, event) } // Should get successor event after resume assert.Equal(t, 1, len(resumeEvents), "should have successor event after resume") assert.Equal(t, "successor", resumeEvents[0].AgentName) } // TestWorkflowAgentUnsupportedMode tests unsupported workflow mode error (lines 65-71) func TestWorkflowAgentUnsupportedMode(t *testing.T) { ctx := context.Background() // Create a workflow agent with unsupported mode agent := &workflowAgent{ name: "UnsupportedModeAgent", description: "Agent with unsupported mode", subAgents: []*flowAgent{}, mode: workflowAgentMode(999), // Invalid mode } // Run the agent and expect error input := &AgentInput{ Messages: []Message{ schema.UserMessage("Test input"), }, } ctx, _ = initRunCtx(ctx, agent.Name(ctx), input) iterator := agent.Run(ctx, input) assert.NotNil(t, iterator) // Should receive an error event due to unsupported mode event, ok := iterator.Next() assert.True(t, ok) assert.NotNil(t, event) assert.NotNil(t, event.Err) assert.Contains(t, event.Err.Error(), "unsupported workflow agent mode") // No more events _, ok = iterator.Next() assert.False(t, ok) } func TestFilterOptions(t *testing.T) { a1 := &myAgent{ name: "Agent1", runFn: func(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { o := GetImplSpecificOptions[myAgentOptions](nil, opts...) 
assert.Equal(t, "Agent1", o.value)
			iter, gen := NewAsyncIteratorPair[*AgentEvent]()
			gen.Close()
			return iter
		},
	}
	a2 := &myAgent{
		name: "Agent2",
		runFn: func(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
			o := GetImplSpecificOptions[myAgentOptions](nil, opts...)
			// Agent2 must only observe the option value designated to it.
			assert.Equal(t, "Agent2", o.value)
			iter, gen := NewAsyncIteratorPair[*AgentEvent]()
			// No events are emitted; closing immediately ends this agent's run.
			gen.Close()
			return iter
		},
	}

	ctx := context.Background()

	// sequential
	seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{
		SubAgents: []Agent{a1, a2},
	})
	assert.NoError(t, err)
	// Each option is designated to exactly one sub-agent; the sub-agents' own
	// assertions check that they do not see each other's value.
	iter := seqAgent.Run(ctx, &AgentInput{}, withValue("Agent1").DesignateAgent("Agent1"), withValue("Agent2").DesignateAgent("Agent2"))
	// The sub-agents emit nothing, so the workflow iterator is exhausted at once.
	_, ok := iter.Next()
	assert.False(t, ok)

	// parallel
	parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{
		SubAgents: []Agent{a1, a2},
	})
	assert.NoError(t, err)
	iter = parAgent.Run(ctx, &AgentInput{}, withValue("Agent1").DesignateAgent("Agent1"), withValue("Agent2").DesignateAgent("Agent2"))
	_, ok = iter.Next()
	assert.False(t, ok)
}

// TestLoopAgentWithError verifies that a LoopAgent stops iterating once its
// sub-agent emits an error event, and that the successful events from the
// earlier iterations are still delivered to the caller.
func TestLoopAgentWithError(t *testing.T) {
	ctx := context.Background()

	iterationCount := 0
	agent := &myAgent{
		name: "ErrorAgent",
		runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] {
			iter, generator := NewAsyncIteratorPair[*AgentEvent]()
			go func() {
				defer generator.Close()
				iterationCount++
				// Fail on the third run; earlier runs emit one assistant message each.
				if iterationCount == 3 {
					generator.Send(&AgentEvent{Err: fmt.Errorf("error on iteration %d", iterationCount)})
					return
				}
				generator.Send(&AgentEvent{
					Output: &AgentOutput{
						MessageOutput: &MessageVariant{
							Message: schema.AssistantMessage(fmt.Sprintf("iteration %d", iterationCount), nil),
							Role:    schema.Assistant,
						},
					},
				})
			}()
			return iter
		},
	}

	// MaxIterations is well above 3, so only the error can stop the loop early.
	loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{
		Name:          "LoopErrorTestAgent",
		SubAgents:     []Agent{agent},
		MaxIterations: 10,
	})
	assert.NoError(t, err)

	input := &AgentInput{Messages: []Message{schema.UserMessage("test")}}
	ctx, _ = initRunCtx(ctx, loopAgent.Name(ctx), input)
	iterator := loopAgent.Run(ctx, input)

	// Split the stream into successful events and the (last) error event.
	var events []*AgentEvent
	var errorEvent *AgentEvent
	for {
		event, ok := iterator.Next()
		if !ok {
			break
		}
		if event.Err != nil {
			errorEvent = event
		} else {
			events = append(events, event)
		}
	}

	assert.Equal(t, 2, len(events), "should have 2 successful iterations before error")
	assert.NotNil(t, errorEvent, "should have received error event")
	assert.Contains(t, errorEvent.Err.Error(), "error on iteration 3")
	assert.Equal(t, 3, iterationCount, "loop should stop at iteration 3")
}

// TestWorkflowCallbackHandlerNotDoubled checks that OnStart callbacks for
// SubSubAgent fire exactly once per run of that agent (the counters track
// iterationCount), both for a globally registered handler and for one
// designated via DesignateAgent("ParentWorkflow", "SubSubAgent"). It checks this
// for the initial run, which is interrupted on the second iteration, and again
// after resuming from the checkpoint.
func TestWorkflowCallbackHandlerNotDoubled(t *testing.T) {
	ctx := context.Background()
	store := newMyStore()

	var globalCallbackCount int
	var designatedCallbackCount int
	var mu sync.Mutex

	globalHandler := callbacks.NewHandlerBuilder().OnStartFn(
		func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
			if info.Component == ComponentOfAgent && info.Name == "SubSubAgent" {
				mu.Lock()
				globalCallbackCount++
				mu.Unlock()
			}
			return ctx
		}).Build()

	designatedHandler := callbacks.NewHandlerBuilder().OnStartFn(
		func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
			if info.Component == ComponentOfAgent && info.Name == "SubSubAgent" {
				mu.Lock()
				designatedCallbackCount++
				mu.Unlock()
			}
			return ctx
		}).Build()

	iterationCount := 0
	// While true, the second run of SubSubAgent emits an interrupt instead of output.
	shouldInterrupt := true

	subSubAgent := &myAgent{
		name: "SubSubAgent",
		runFn: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] {
			iter, generator := NewAsyncIteratorPair[*AgentEvent]()
			go func() {
				defer generator.Close()
				iterationCount++
				if shouldInterrupt && iterationCount == 2 {
					generator.Send(Interrupt(ctx, "test_interrupt"))
					return
				}
				generator.Send(&AgentEvent{
					Output: &AgentOutput{
						MessageOutput: &MessageVariant{
							Message: schema.AssistantMessage(fmt.Sprintf("iteration %d", iterationCount), nil),
							Role:    schema.Assistant,
						},
					},
				})
			}()
			return iter
		},
		resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
			iter, generator := NewAsyncIteratorPair[*AgentEvent]()
			go func() {
				defer generator.Close()
				iterationCount++
				generator.Send(&AgentEvent{
					Output: &AgentOutput{
						MessageOutput: &MessageVariant{
							Message: schema.AssistantMessage(fmt.Sprintf("resumed iteration %d", iterationCount), nil),
							Role:    schema.Assistant,
						},
					},
				})
			}()
			return iter
		},
	}

	// Two nested loops of 2 iterations each: ParentWorkflow -> SubWorkflow -> SubSubAgent.
	subWorkflow, err := NewLoopAgent(ctx, &LoopAgentConfig{
		Name:          "SubWorkflow",
		SubAgents:     []Agent{subSubAgent},
		MaxIterations: 2,
	})
	assert.NoError(t, err)

	parentWorkflow, err := NewLoopAgent(ctx, &LoopAgentConfig{
		Name:          "ParentWorkflow",
		SubAgents:     []Agent{subWorkflow},
		MaxIterations: 2,
	})
	assert.NoError(t, err)

	runner := NewRunner(ctx, RunnerConfig{
		Agent:           parentWorkflow,
		CheckPointStore: store,
	})

	opts := []AgentRunOption{
		WithCallbacks(globalHandler),
		WithCallbacks(designatedHandler).DesignateAgent("ParentWorkflow", "SubSubAgent"),
		WithCheckPointID("cp1"),
	}

	iterator := runner.Run(ctx, []Message{schema.UserMessage("test")}, opts...)
	var interruptEvent *AgentEvent
	for {
		event, ok := iterator.Next()
		if !ok {
			break
		}
		if event.Action != nil && event.Action.Interrupted != nil {
			interruptEvent = event
		}
	}

	assert.NotNil(t, interruptEvent)
	assert.Equal(t, 2, iterationCount)
	assert.Equal(t, 2, globalCallbackCount)
	assert.Equal(t, 2, designatedCallbackCount)

	shouldInterrupt = false

	// Resume by targeting the interrupt context marked as the root cause.
	var rootCauseID string
	for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts {
		if intCtx.IsRootCause {
			rootCauseID = intCtx.ID
			break
		}
	}

	resumeIter, err := runner.ResumeWithParams(ctx, "cp1", &ResumeParams{
		Targets: map[string]any{rootCauseID: nil},
	}, opts...)
	assert.NoError(t, err)
	for {
		_, ok := resumeIter.Next()
		if !ok {
			break
		}
	}

	// Callback counts must keep matching the number of SubSubAgent runs.
	assert.Equal(t, 5, iterationCount)
	assert.Equal(t, 5, globalCallbackCount)
	assert.Equal(t, 5, designatedCallbackCount)
}

// TestLoopAgentWithBreakLoopFollowedByMoreEvents checks that when a sub-agent
// emits a BreakLoop action followed by further events in the same run, those
// later events are still delivered and the loop does not start another iteration.
func TestLoopAgentWithBreakLoopFollowedByMoreEvents(t *testing.T) {
	ctx := context.Background()

	agent := newMockAgent("SubAgent", "Sub agent", []*AgentEvent{
		{
			AgentName: "SubAgent",
			Output: &AgentOutput{
				MessageOutput: &MessageVariant{
					IsStreaming: false,
					Message:     schema.ToolMessage("tool result", "call_123"),
					Role:        schema.Tool,
				},
			},
			Action: NewBreakLoopAction("SubAgent"),
		},
		{
			AgentName: "SubAgent",
			Output: &AgentOutput{
				MessageOutput: &MessageVariant{
					IsStreaming: false,
					Message:     schema.AssistantMessage("Final response after tool", nil),
					Role:        schema.Assistant,
				},
			},
		},
	})

	loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{
		Name:          "LoopTestAgent",
		Description:   "Test loop agent",
		SubAgents:     []Agent{agent},
		MaxIterations: 3,
	})
	assert.NoError(t, err)
	assert.NotNil(t, loopAgent)

	input := &AgentInput{
		Messages: []Message{
			schema.UserMessage("Test input"),
		},
	}
	ctx, _ = initRunCtx(ctx, loopAgent.Name(ctx), input)
	iterator := loopAgent.Run(ctx, input)
	assert.NotNil(t, iterator)

	var events []*AgentEvent
	for {
		event, ok := iterator.Next()
		if !ok {
			break
		}
		events = append(events, event)
	}

	assert.Equal(t, 2, len(events), "should have 2 events (tool event with BreakLoop + final response) and loop should break")

	assert.NotNil(t, events[0].Action, "first event should have an action")
	assert.NotNil(t, events[0].Action.BreakLoop, "first event should have BreakLoop action")
	assert.True(t, events[0].Action.BreakLoop.Done, "BreakLoop should be marked as Done")
	assert.Equal(t, "SubAgent", events[0].Action.BreakLoop.From)
	assert.Equal(t, 0, events[0].Action.BreakLoop.CurrentIterations)
	assert.Equal(t, schema.Tool, events[0].Output.MessageOutput.Role, "first event should be tool message")

	assert.Nil(t, events[1].Action, "second event should not have an action")
	assert.Equal(t, schema.Assistant, events[1].Output.MessageOutput.Role, "second event should be assistant message")
	assert.Equal(t, "Final response after tool", events[1].Output.MessageOutput.Message.Content)
}

================================================
FILE: adk/wrappers.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package adk

import (
	"context"
	"errors"
	"reflect"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/internal/generic"
	"github.com/cloudwego/eino/schema"
)

// generateEndpoint has the shape of model.BaseChatModel.Generate; wrappers
// compose these function values to build a chain around the real model.
type generateEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error)

// streamEndpoint has the shape of model.BaseChatModel.Stream.
type streamEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error)

// modelWrapperConfig collects everything buildModelWrappers needs to wrap a model.
type modelWrapperConfig struct {
	handlers    []ChatModelAgentMiddleware
	middlewares []AgentMiddleware
	retryConfig *ModelRetryConfig
	toolInfos   []*schema.ToolInfo
}

// buildModelWrappers wraps m so that model calls go through the agent's state
// handling, middlewares, handlers, event sending and (optional) retry.
//
// If m does not report that it handles callbacks itself, a
// callbackInjectionModelWrapper is inserted first so that OnStart/OnEnd/OnError
// callbacks are still triggered around the raw model call. The outermost layer is
// always a stateModelWrapper; it keeps a reference to the unwrapped m so it can
// report m's type.
func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model.BaseChatModel {
	var wrapped model.BaseChatModel = m

	if !components.IsCallbacksEnabled(m) {
		wrapped = (&callbackInjectionModelWrapper{}).WrapModel(wrapped)
	}

	wrapped = &stateModelWrapper{
		inner:            wrapped,
		original:         m,
		handlers:         config.handlers,
		middlewares:      config.middlewares,
		toolInfos:        config.toolInfos,
		modelRetryConfig: config.retryConfig,
	}

	return wrapped
}

// callbackInjectionModelWrapper wraps models that do not trigger callbacks
// themselves with callbackInjectedModel.
type callbackInjectionModelWrapper struct{}

// WrapModel returns m wrapped in a callbackInjectedModel.
func (w *callbackInjectionModelWrapper) WrapModel(m model.BaseChatModel) model.BaseChatModel {
	return &callbackInjectedModel{inner: m}
}

// callbackInjectedModel triggers callbacks.OnStart / OnEnd / OnError around
// every call to the inner model.
type callbackInjectedModel struct {
	inner model.BaseChatModel
}

// Generate runs OnStart, calls the inner model, then runs OnError on failure or
// OnEnd with the result on success.
func (m *callbackInjectedModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	ctx = callbacks.OnStart(ctx, input)
	result, err := m.inner.Generate(ctx, input, opts...)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}
	callbacks.OnEnd(ctx, result)
	return result, nil
}

// Stream runs OnStart, calls the inner model, then runs OnError on failure. On
// success it returns the stream produced by callbacks.OnEndWithStreamOutput
// instead of the raw stream, so end callbacks can observe the output.
func (m *callbackInjectedModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	ctx = callbacks.OnStart(ctx, input)
	result, err := m.inner.Stream(ctx, input, opts...)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}
	_, wrappedStream := callbacks.OnEndWithStreamOutput(ctx, result)
	return wrappedStream, nil
}

// handlersToToolMiddlewares adapts the tool-call wrapping hooks of each
// ChatModelAgentMiddleware into a compose.ToolMiddleware. The hooks are
// WrapInvokableToolCall, WrapStreamableToolCall, WrapEnhancedInvokableToolCall
// and WrapEnhancedStreamableToolCall.
//
// Handlers are visited from last to first. The returned slice is therefore in
// reverse handler order: middlewares[0] corresponds to handlers[len(handlers)-1].
// NOTE(review): presumably this matches the order in which compose applies a
// ToolMiddleware slice, so that handlers[0] ends up outermost — confirm against compose.
//
// Each adapter does three things:
//   - builds a ToolContext from the tool input;
//   - asks the handler to wrap an endpoint that calls next with the (possibly
//     rewritten) arguments and options;
//   - invokes that wrapped endpoint with the original arguments and options.
//
// An error returned by the handler while wrapping is propagated unchanged.
func handlersToToolMiddlewares(handlers []ChatModelAgentMiddleware) []compose.ToolMiddleware {
	var middlewares []compose.ToolMiddleware
	for i := len(handlers) - 1; i >= 0; i-- {
		handler := handlers[i]

		m := compose.ToolMiddleware{}

		h := handler
		// Plain invokable tools: string arguments in, string result out.
		m.Invokable = func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
			return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
				tCtx := &ToolContext{
					Name:   input.Name,
					CallID: input.CallID,
				}
				wrappedEndpoint, err := h.WrapInvokableToolCall(
					ctx,
					func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
						output, err := next(ctx, &compose.ToolInput{
							Name:        input.Name,
							CallID:      input.CallID,
							Arguments:   argumentsInJSON,
							CallOptions: opts,
						})
						if err != nil {
							return "", err
						}
						return output.Result, nil
					},
					tCtx,
				)
				if err != nil {
					return nil, err
				}
				result, err := wrappedEndpoint(ctx, input.Arguments, input.CallOptions...)
				if err != nil {
					return nil, err
				}
				return &compose.ToolOutput{Result: result}, nil
			}
		}
		// Streamable tools: string arguments in, stream of strings out.
		m.Streamable = func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
			return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
				tCtx := &ToolContext{
					Name:   input.Name,
					CallID: input.CallID,
				}
				wrappedEndpoint, err := h.WrapStreamableToolCall(
					ctx,
					func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) {
						output, err := next(ctx, &compose.ToolInput{
							Name:        input.Name,
							CallID:      input.CallID,
							Arguments:   argumentsInJSON,
							CallOptions: opts,
						})
						if err != nil {
							return nil, err
						}
						return output.Result, nil
					},
					tCtx,
				)
				if err != nil {
					return nil, err
				}
				result, err := wrappedEndpoint(ctx, input.Arguments, input.CallOptions...)
				if err != nil {
					return nil, err
				}
				return &compose.StreamToolOutput{Result: result}, nil
			}
		}
		// Enhanced invokable tools: the handler sees a *schema.ToolArgument, whose
		// Text carries the JSON arguments, and returns a *schema.ToolResult.
		m.EnhancedInvokable = func(next compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint {
			return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) {
				tCtx := &ToolContext{
					Name:   input.Name,
					CallID: input.CallID,
				}
				wrappedEndpoint, err := h.WrapEnhancedInvokableToolCall(
					ctx,
					func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) {
						output, err := next(ctx, &compose.ToolInput{
							Name:        input.Name,
							CallID:      input.CallID,
							Arguments:   toolArgument.Text,
							CallOptions: opts,
						})
						if err != nil {
							return nil, err
						}
						return output.Result, nil
					},
					tCtx,
				)
				if err != nil {
					return nil, err
				}
				result, err := wrappedEndpoint(ctx, &schema.ToolArgument{Text: input.Arguments}, input.CallOptions...)
				if err != nil {
					return nil, err
				}
				return &compose.EnhancedInvokableToolOutput{Result: result}, nil
			}
		}
		// Enhanced streamable tools: ToolArgument in, stream of *schema.ToolResult out.
		m.EnhancedStreamable = func(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint {
			return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) {
				tCtx := &ToolContext{
					Name:   input.Name,
					CallID: input.CallID,
				}
				wrappedEndpoint, err := h.WrapEnhancedStreamableToolCall(
					ctx,
					func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) {
						output, err := next(ctx, &compose.ToolInput{
							Name:        input.Name,
							CallID:      input.CallID,
							Arguments:   toolArgument.Text,
							CallOptions: opts,
						})
						if err != nil {
							return nil, err
						}
						return output.Result, nil
					},
					tCtx,
				)
				if err != nil {
					return nil, err
				}
				result, err := wrappedEndpoint(ctx, &schema.ToolArgument{Text: input.Arguments}, input.CallOptions...)
				if err != nil {
					return nil, err
				}
				return &compose.EnhancedStreamableToolOutput{Result: result}, nil
			}
		}

		middlewares = append(middlewares, m)
	}
	return middlewares
}

// eventSenderModelWrapper is the ChatModelAgentMiddleware returned by
// NewEventSenderModelWrapper. Only WrapModel is overridden; every other hook
// comes from the embedded BaseChatModelAgentMiddleware.
type eventSenderModelWrapper struct {
	*BaseChatModelAgentMiddleware
}

// NewEventSenderModelWrapper returns a ChatModelAgentMiddleware that sends model
// response events.
//
// By default, the framework applies this wrapper after all user middlewares, so
// events contain the modified messages. To send events with the original
// (unmodified) output instead, pass this as a Handler placed after the modifying
// middleware; that puts it innermost in the wrapper chain.
//
// When an instance is found in Handlers, the framework skips its default event
// sender to avoid sending duplicate events.
func NewEventSenderModelWrapper() ChatModelAgentMiddleware { return &eventSenderModelWrapper{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, } } func (w *eventSenderModelWrapper) WrapModel(_ context.Context, m model.BaseChatModel, mc *ModelContext) (model.BaseChatModel, error) { var retryConfig *ModelRetryConfig if mc != nil { retryConfig = mc.ModelRetryConfig } return &eventSenderModel{inner: m, modelRetryConfig: retryConfig}, nil } type eventSenderModel struct { inner model.BaseChatModel modelRetryConfig *ModelRetryConfig } func (m *eventSenderModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { result, err := m.inner.Generate(ctx, input, opts...) if err != nil { return nil, err } execCtx := getChatModelAgentExecCtx(ctx) if execCtx == nil || execCtx.generator == nil { return nil, errors.New("generator is nil when sending event in Generate: ensure agent state is properly initialized") } msgCopy := *result event := EventFromMessage(&msgCopy, nil, schema.Assistant, "") execCtx.send(event) return result, nil } func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { result, err := m.inner.Stream(ctx, input, opts...) 
if err != nil { return nil, err } execCtx := getChatModelAgentExecCtx(ctx) if execCtx == nil || execCtx.generator == nil { result.Close() return nil, errors.New("generator is nil when sending event in Stream: ensure agent state is properly initialized") } var retryAttempt int _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { retryAttempt = st.getRetryAttempt() return nil }) streams := result.Copy(2) eventStream := streams[0] if m.modelRetryConfig != nil { convertOpts := []schema.ConvertOption{ schema.WithErrWrapper(genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, retryAttempt, m.modelRetryConfig.IsRetryAble)), } eventStream = schema.StreamReaderWithConvert(streams[0], func(msg *schema.Message) (*schema.Message, error) { return msg, nil }, convertOpts...) } event := EventFromMessage(nil, eventStream, schema.Assistant, "") execCtx.send(event) return streams[1], nil } func popToolGenAction(ctx context.Context, toolName string) *AgentAction { toolCallID := compose.GetToolCallID(ctx) var action *AgentAction _ = compose.ProcessState(ctx, func(ctx context.Context, st *State) error { if len(toolCallID) > 0 { if a := st.popToolGenAction(toolCallID); a != nil { action = a return nil } } if a := st.popToolGenAction(toolName); a != nil { action = a } return nil }) return action } type eventSenderToolHandler struct{} func (h *eventSenderToolHandler) WrapInvokableToolCall(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { output, err := next(ctx, input) if err != nil { return nil, err } toolName := input.Name callID := input.CallID prePopAction := popToolGenAction(ctx, toolName) msg := schema.ToolMessage(output.Result, callID, schema.WithToolName(toolName)) event := EventFromMessage(msg, nil, schema.Tool, toolName) if prePopAction != nil { event.Action = prePopAction } execCtx := getChatModelAgentExecCtx(ctx) _ = compose.ProcessState(ctx, func(_ 
context.Context, st *State) error { if st.getReturnDirectlyToolCallID() == callID { st.setReturnDirectlyEvent(event) } else { execCtx.send(event) } return nil }) return output, nil } } func (h *eventSenderToolHandler) WrapStreamableToolCall(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { output, err := next(ctx, input) if err != nil { return nil, err } toolName := input.Name callID := input.CallID prePopAction := popToolGenAction(ctx, toolName) streams := output.Result.Copy(2) cvt := func(in string) (Message, error) { return schema.ToolMessage(in, callID, schema.WithToolName(toolName)), nil } msgStream := schema.StreamReaderWithConvert(streams[0], cvt) event := EventFromMessage(nil, msgStream, schema.Tool, toolName) event.Action = prePopAction execCtx := getChatModelAgentExecCtx(ctx) _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.getReturnDirectlyToolCallID() == callID { st.setReturnDirectlyEvent(event) } else { execCtx.send(event) } return nil }) return &compose.StreamToolOutput{Result: streams[1]}, nil } } func (h *eventSenderToolHandler) WrapEnhancedInvokableToolCall(next compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { output, err := next(ctx, input) if err != nil { return nil, err } toolName := input.Name callID := input.CallID prePopAction := popToolGenAction(ctx, toolName) msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) msg.UserInputMultiContent, err = output.Result.ToMessageInputParts() if err != nil { return nil, err } event := EventFromMessage(msg, nil, schema.Tool, toolName) if prePopAction != nil { event.Action = prePopAction } execCtx := getChatModelAgentExecCtx(ctx) _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { if 
st.getReturnDirectlyToolCallID() == callID { st.setReturnDirectlyEvent(event) } else { execCtx.send(event) } return nil }) return output, nil } } func (h *eventSenderToolHandler) WrapEnhancedStreamableToolCall(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { output, err := next(ctx, input) if err != nil { return nil, err } toolName := input.Name callID := input.CallID prePopAction := popToolGenAction(ctx, toolName) streams := output.Result.Copy(2) cvt := func(in *schema.ToolResult) (Message, error) { msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) var cvtErr error msg.UserInputMultiContent, cvtErr = in.ToMessageInputParts() if cvtErr != nil { return nil, cvtErr } return msg, nil } msgStream := schema.StreamReaderWithConvert(streams[0], cvt) event := EventFromMessage(nil, msgStream, schema.Tool, toolName) event.Action = prePopAction execCtx := getChatModelAgentExecCtx(ctx) _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.getReturnDirectlyToolCallID() == callID { st.setReturnDirectlyEvent(event) } else { execCtx.send(event) } return nil }) return &compose.EnhancedStreamableToolOutput{Result: streams[1]}, nil } } type stateModelWrapper struct { inner model.BaseChatModel original model.BaseChatModel handlers []ChatModelAgentMiddleware middlewares []AgentMiddleware toolInfos []*schema.ToolInfo modelRetryConfig *ModelRetryConfig } func (w *stateModelWrapper) IsCallbacksEnabled() bool { return true } func (w *stateModelWrapper) GetType() string { if typer, ok := w.original.(components.Typer); ok { return typer.GetType() } return generic.ParseTypeName(reflect.ValueOf(w.original)) } func (w *stateModelWrapper) hasUserEventSender() bool { for _, handler := range w.handlers { if _, ok := handler.(*eventSenderModelWrapper); ok { return true } } return false } func (w 
*stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) generateEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig for i := len(w.handlers) - 1; i >= 0; i-- { handler := w.handlers[i] innerEndpoint := endpoint baseToolInfos := w.toolInfos endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig} wrappedModel, err := handler.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err } return wrappedModel.Generate(ctx, input, opts...) } } if !hasUserEventSender { innerEndpoint := endpoint eventSender := NewEventSenderModelWrapper() endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { execCtx := getChatModelAgentExecCtx(ctx) if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } mc := &ModelContext{ModelRetryConfig: retryConfig} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err } return wrappedModel.Generate(ctx, input, opts...) } } if w.modelRetryConfig != nil { innerEndpoint := endpoint endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { retryWrapper := newRetryModelWrapper(&endpointModel{generate: innerEndpoint}, w.modelRetryConfig) return retryWrapper.Generate(ctx, input, opts...) 
} } return endpoint } func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig for i := len(w.handlers) - 1; i >= 0; i-- { handler := w.handlers[i] innerEndpoint := endpoint baseToolInfos := w.toolInfos endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig} wrappedModel, err := handler.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err } return wrappedModel.Stream(ctx, input, opts...) } } if !hasUserEventSender { innerEndpoint := endpoint eventSender := NewEventSenderModelWrapper() endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { execCtx := getChatModelAgentExecCtx(ctx) if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } mc := &ModelContext{ModelRetryConfig: retryConfig} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err } return wrappedModel.Stream(ctx, input, opts...) } } if w.modelRetryConfig != nil { innerEndpoint := endpoint endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { retryWrapper := newRetryModelWrapper(&endpointModel{stream: innerEndpoint}, w.modelRetryConfig) return retryWrapper.Stream(ctx, input, opts...) 
} } return endpoint } func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { var stateMessages []Message _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { stateMessages = st.Messages return nil }) state := &ChatModelAgentState{Messages: append(stateMessages, input...)} for _, m := range w.middlewares { if m.BeforeChatModel != nil { if err := m.BeforeChatModel(ctx, state); err != nil { return nil, err } } } baseOpts := &model.Options{Tools: w.toolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig} for _, handler := range w.handlers { var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) if err != nil { return nil, err } } _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { st.Messages = state.Messages return nil }) wrappedEndpoint := w.wrapGenerateEndpoint(w.inner.Generate) result, err := wrappedEndpoint(ctx, state.Messages, opts...) 
if err != nil { return nil, err } state.Messages = append(state.Messages, result) for _, handler := range w.handlers { ctx, state, err = handler.AfterModelRewriteState(ctx, state, mc) if err != nil { return nil, err } } for _, m := range w.middlewares { if m.AfterChatModel != nil { if err := m.AfterChatModel(ctx, state); err != nil { return nil, err } } } _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { st.Messages = state.Messages return nil }) if len(state.Messages) == 0 { return nil, errors.New("no messages left in state after model call") } return state.Messages[len(state.Messages)-1], nil } func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { var stateMessages []Message _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { stateMessages = st.Messages return nil }) state := &ChatModelAgentState{Messages: append(stateMessages, input...)} for _, m := range w.middlewares { if m.BeforeChatModel != nil { if err := m.BeforeChatModel(ctx, state); err != nil { return nil, err } } } baseOpts := &model.Options{Tools: w.toolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig} for _, handler := range w.handlers { var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) if err != nil { return nil, err } } _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { st.Messages = state.Messages return nil }) wrappedEndpoint := w.wrapStreamEndpoint(w.inner.Stream) stream, err := wrappedEndpoint(ctx, state.Messages, opts...) 
if err != nil { return nil, err } result, err := schema.ConcatMessageStream(stream) if err != nil { return nil, err } state.Messages = append(state.Messages, result) for _, handler := range w.handlers { ctx, state, err = handler.AfterModelRewriteState(ctx, state, mc) if err != nil { return nil, err } } for _, m := range w.middlewares { if m.AfterChatModel != nil { if err := m.AfterChatModel(ctx, state); err != nil { return nil, err } } } _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { st.Messages = state.Messages return nil }) if len(state.Messages) == 0 { return nil, errors.New("no messages left in state after model call") } return schema.StreamReaderFromArray([]*schema.Message{state.Messages[len(state.Messages)-1]}), nil } type endpointModel struct { generate generateEndpoint stream streamEndpoint } func (m *endpointModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { if m.generate != nil { return m.generate(ctx, input, opts...) } return nil, errors.New("generate endpoint not set") } func (m *endpointModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { if m.stream != nil { return m.stream(ctx, input, opts...) } return nil, errors.New("stream endpoint not set") } ================================================ FILE: adk/wrappers_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package adk import ( "context" "errors" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) type testEnhancedToolWrapperHandler struct { *BaseChatModelAgentMiddleware wrapEnhancedInvokableFn func(context.Context, EnhancedInvokableToolCallEndpoint, *ToolContext) EnhancedInvokableToolCallEndpoint wrapEnhancedStreamableFn func(context.Context, EnhancedStreamableToolCallEndpoint, *ToolContext) EnhancedStreamableToolCallEndpoint } func (h *testEnhancedToolWrapperHandler) WrapEnhancedInvokableToolCall(ctx context.Context, endpoint EnhancedInvokableToolCallEndpoint, tCtx *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { if h.wrapEnhancedInvokableFn != nil { return h.wrapEnhancedInvokableFn(ctx, endpoint, tCtx), nil } return endpoint, nil } func (h *testEnhancedToolWrapperHandler) WrapEnhancedStreamableToolCall(ctx context.Context, endpoint EnhancedStreamableToolCallEndpoint, tCtx *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { if h.wrapEnhancedStreamableFn != nil { return h.wrapEnhancedStreamableFn(ctx, endpoint, tCtx), nil } return endpoint, nil } func newTestEnhancedInvokableToolCallWrapper(beforeFn, afterFn func()) func(context.Context, EnhancedInvokableToolCallEndpoint, *ToolContext) EnhancedInvokableToolCallEndpoint { return func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) EnhancedInvokableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { if beforeFn != nil { beforeFn() } result, err := endpoint(ctx, toolArgument, opts...) 
if afterFn != nil { afterFn() } return result, err } } } func newTestEnhancedStreamableToolCallWrapper(beforeFn, afterFn func()) func(context.Context, EnhancedStreamableToolCallEndpoint, *ToolContext) EnhancedStreamableToolCallEndpoint { return func(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) EnhancedStreamableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { if beforeFn != nil { beforeFn() } result, err := endpoint(ctx, toolArgument, opts...) if afterFn != nil { afterFn() } return result, err } } } func TestHandlersToToolMiddlewaresEnhanced(t *testing.T) { t.Run("OnlyEnhancedInvokableHandler", func(t *testing.T) { var called bool handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) EnhancedInvokableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { called = true return endpoint(ctx, toolArgument, opts...) 
} }, }, } middlewares := handlersToToolMiddlewares(handlers) assert.Len(t, middlewares, 1) assert.NotNil(t, middlewares[0].EnhancedInvokable) assert.NotNil(t, middlewares[0].Invokable) assert.NotNil(t, middlewares[0].Streamable) assert.NotNil(t, middlewares[0].EnhancedStreamable) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return &compose.EnhancedInvokableToolOutput{ Result: &schema.ToolResult{ Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "test"}}, }, }, nil } wrapped := middlewares[0].EnhancedInvokable(mockEndpoint) _, err := wrapped(context.Background(), &compose.ToolInput{ Name: "test_tool", CallID: "call-1", Arguments: `{"input": "test"}`, }) assert.NoError(t, err) assert.True(t, called) }) t.Run("OnlyEnhancedStreamableHandler", func(t *testing.T) { var called bool handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedStreamableFn: func(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) EnhancedStreamableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { called = true return endpoint(ctx, toolArgument, opts...) 
} }, }, } middlewares := handlersToToolMiddlewares(handlers) assert.Len(t, middlewares, 1) assert.NotNil(t, middlewares[0].EnhancedStreamable) assert.NotNil(t, middlewares[0].Invokable) assert.NotNil(t, middlewares[0].Streamable) assert.NotNil(t, middlewares[0].EnhancedInvokable) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { return &compose.EnhancedStreamableToolOutput{ Result: schema.StreamReaderFromArray([]*schema.ToolResult{ {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "test"}}}, }), }, nil } wrapped := middlewares[0].EnhancedStreamable(mockEndpoint) _, err := wrapped(context.Background(), &compose.ToolInput{ Name: "test_tool", CallID: "call-1", Arguments: `{"input": "test"}`, }) assert.NoError(t, err) assert.True(t, called) }) t.Run("MixedHandlers", func(t *testing.T) { var invokableCalled, streamableCalled, enhancedInvokableCalled, enhancedStreamableCalled bool handlers := []ChatModelAgentMiddleware{ &testToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapInvokableFn: func(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) InvokableToolCallEndpoint { return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { invokableCalled = true return endpoint(ctx, argumentsInJSON, opts...) } }, }, &testToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapStreamableFn: func(_ context.Context, endpoint StreamableToolCallEndpoint, _ *ToolContext) StreamableToolCallEndpoint { return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { streamableCalled = true return endpoint(ctx, argumentsInJSON, opts...) 
} }, }, &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) EnhancedInvokableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { enhancedInvokableCalled = true return endpoint(ctx, toolArgument, opts...) } }, }, &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedStreamableFn: func(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) EnhancedStreamableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { enhancedStreamableCalled = true return endpoint(ctx, toolArgument, opts...) } }, }, } middlewares := handlersToToolMiddlewares(handlers) assert.Len(t, middlewares, 4) invokableEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { return &compose.ToolOutput{Result: "test"}, nil } _, _ = middlewares[3].Invokable(invokableEndpoint)(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) streamableEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { return &compose.StreamToolOutput{Result: schema.StreamReaderFromArray([]string{"test"})}, nil } _, _ = middlewares[2].Streamable(streamableEndpoint)(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) enhancedInvokableEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return &compose.EnhancedInvokableToolOutput{Result: &schema.ToolResult{}}, nil } _, _ = middlewares[1].EnhancedInvokable(enhancedInvokableEndpoint)(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: 
"{}"}) enhancedStreamableEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { return &compose.EnhancedStreamableToolOutput{Result: schema.StreamReaderFromArray([]*schema.ToolResult{{}})}, nil } _, _ = middlewares[0].EnhancedStreamable(enhancedStreamableEndpoint)(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.True(t, invokableCalled) assert.True(t, streamableCalled) assert.True(t, enhancedInvokableCalled) assert.True(t, enhancedStreamableCalled) }) t.Run("NoHandlers", func(t *testing.T) { handlers := []ChatModelAgentMiddleware{} middlewares := handlersToToolMiddlewares(handlers) assert.Len(t, middlewares, 0) }) t.Run("HandlerWithNoToolWrappers", func(t *testing.T) { handlers := []ChatModelAgentMiddleware{ &BaseChatModelAgentMiddleware{}, } middlewares := handlersToToolMiddlewares(handlers) assert.Len(t, middlewares, 1) }) t.Run("EnhancedInvokableToolCallErrorPropagation", func(t *testing.T) { expectedErr := errors.New("test error") handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) EnhancedInvokableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { return nil, expectedErr } }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return &compose.EnhancedInvokableToolOutput{Result: &schema.ToolResult{}}, nil } wrapped := middlewares[0].EnhancedInvokable(mockEndpoint) _, err := wrapped(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.Error(t, err) assert.Equal(t, expectedErr, err) }) 
t.Run("EnhancedStreamableToolCallErrorPropagation", func(t *testing.T) { expectedErr := errors.New("test error") handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedStreamableFn: func(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) EnhancedStreamableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { return nil, expectedErr } }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { return &compose.EnhancedStreamableToolOutput{Result: schema.StreamReaderFromArray([]*schema.ToolResult{})}, nil } wrapped := middlewares[0].EnhancedStreamable(mockEndpoint) _, err := wrapped(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.Error(t, err) assert.Equal(t, expectedErr, err) }) t.Run("MultipleEnhancedInvokableWrappers", func(t *testing.T) { var executionOrder []string var mu sync.Mutex handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: newTestEnhancedInvokableToolCallWrapper( func() { mu.Lock() executionOrder = append(executionOrder, "handler1-before") mu.Unlock() }, func() { mu.Lock() executionOrder = append(executionOrder, "handler1-after") mu.Unlock() }, ), }, &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: newTestEnhancedInvokableToolCallWrapper( func() { mu.Lock() executionOrder = append(executionOrder, "handler2-before") mu.Unlock() }, func() { mu.Lock() executionOrder = append(executionOrder, "handler2-after") mu.Unlock() }, ), }, } middlewares := handlersToToolMiddlewares(handlers) assert.Len(t, 
middlewares, 2) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return &compose.EnhancedInvokableToolOutput{Result: &schema.ToolResult{}}, nil } wrapped := middlewares[0].EnhancedInvokable(middlewares[1].EnhancedInvokable(mockEndpoint)) _, err := wrapped(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.NoError(t, err) assert.Equal(t, []string{"handler2-before", "handler1-before", "handler1-after", "handler2-after"}, executionOrder) }) t.Run("MultipleEnhancedStreamableWrappers", func(t *testing.T) { var executionOrder []string var mu sync.Mutex handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedStreamableFn: newTestEnhancedStreamableToolCallWrapper( func() { mu.Lock() executionOrder = append(executionOrder, "handler1-before") mu.Unlock() }, func() { mu.Lock() executionOrder = append(executionOrder, "handler1-after") mu.Unlock() }, ), }, &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedStreamableFn: newTestEnhancedStreamableToolCallWrapper( func() { mu.Lock() executionOrder = append(executionOrder, "handler2-before") mu.Unlock() }, func() { mu.Lock() executionOrder = append(executionOrder, "handler2-after") mu.Unlock() }, ), }, } middlewares := handlersToToolMiddlewares(handlers) assert.Len(t, middlewares, 2) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { return &compose.EnhancedStreamableToolOutput{Result: schema.StreamReaderFromArray([]*schema.ToolResult{{}})}, nil } wrapped := middlewares[0].EnhancedStreamable(middlewares[1].EnhancedStreamable(mockEndpoint)) _, err := wrapped(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.NoError(t, err) assert.Equal(t, []string{"handler2-before", 
"handler1-before", "handler1-after", "handler2-after"}, executionOrder) }) } func TestEnhancedToolContextPropagation(t *testing.T) { t.Run("ToolContextContainsCorrectInfo", func(t *testing.T) { var capturedCtx *ToolContext handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, tCtx *ToolContext) EnhancedInvokableToolCallEndpoint { capturedCtx = tCtx return endpoint }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return &compose.EnhancedInvokableToolOutput{Result: &schema.ToolResult{}}, nil } wrapped := middlewares[0].EnhancedInvokable(mockEndpoint) _, _ = wrapped(context.Background(), &compose.ToolInput{ Name: "my_tool", CallID: "call-123", Arguments: `{"key": "value"}`, }) assert.NotNil(t, capturedCtx) assert.Equal(t, "my_tool", capturedCtx.Name) assert.Equal(t, "call-123", capturedCtx.CallID) }) t.Run("StreamableToolContextContainsCorrectInfo", func(t *testing.T) { var capturedCtx *ToolContext handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedStreamableFn: func(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, tCtx *ToolContext) EnhancedStreamableToolCallEndpoint { capturedCtx = tCtx return endpoint }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { return &compose.EnhancedStreamableToolOutput{Result: schema.StreamReaderFromArray([]*schema.ToolResult{{}})}, nil } wrapped := middlewares[0].EnhancedStreamable(mockEndpoint) _, _ = wrapped(context.Background(), &compose.ToolInput{ Name: "stream_tool", CallID: "call-456", Arguments: 
`{"data": "test"}`, }) assert.NotNil(t, capturedCtx) assert.Equal(t, "stream_tool", capturedCtx.Name) assert.Equal(t, "call-456", capturedCtx.CallID) }) } func TestBaseChatModelAgentMiddlewareEnhancedDefaults(t *testing.T) { t.Run("DefaultEnhancedInvokableReturnsEndpoint", func(t *testing.T) { base := &BaseChatModelAgentMiddleware{} var called bool endpoint := func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { called = true return &schema.ToolResult{}, nil } wrapped, wrapErr := base.WrapEnhancedInvokableToolCall(context.Background(), endpoint, &ToolContext{Name: "test", CallID: "1"}) assert.NoError(t, wrapErr) _, err := wrapped(context.Background(), &schema.ToolArgument{Text: "{}"}) assert.NoError(t, err) assert.True(t, called) }) t.Run("DefaultEnhancedStreamableReturnsEndpoint", func(t *testing.T) { base := &BaseChatModelAgentMiddleware{} var called bool endpoint := func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { called = true return schema.StreamReaderFromArray([]*schema.ToolResult{}), nil } wrapped, wrapErr := base.WrapEnhancedStreamableToolCall(context.Background(), endpoint, &ToolContext{Name: "test", CallID: "1"}) assert.NoError(t, wrapErr) _, err := wrapped(context.Background(), &schema.ToolArgument{Text: "{}"}) assert.NoError(t, err) assert.True(t, called) }) } func TestEnhancedToolArgumentsPropagation(t *testing.T) { t.Run("ArgumentsPassedCorrectly", func(t *testing.T) { var capturedArgs string handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) EnhancedInvokableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { capturedArgs = 
toolArgument.Text return endpoint(ctx, toolArgument, opts...) } }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return &compose.EnhancedInvokableToolOutput{Result: &schema.ToolResult{}}, nil } wrapped := middlewares[0].EnhancedInvokable(mockEndpoint) _, _ = wrapped(context.Background(), &compose.ToolInput{ Name: "test_tool", CallID: "call-1", Arguments: `{"name": "test", "value": 123}`, }) assert.Equal(t, `{"name": "test", "value": 123}`, capturedArgs) }) } func TestEnhancedToolResultPropagation(t *testing.T) { t.Run("ResultPassedThroughMiddleware", func(t *testing.T) { expectedResult := &schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: "original result"}, }, } var capturedResult *schema.ToolResult handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) EnhancedInvokableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { result, err := endpoint(ctx, toolArgument, opts...) 
capturedResult = result return result, err } }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return &compose.EnhancedInvokableToolOutput{Result: expectedResult}, nil } wrapped := middlewares[0].EnhancedInvokable(mockEndpoint) output, err := wrapped(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.NoError(t, err) assert.Equal(t, expectedResult, capturedResult) assert.Equal(t, expectedResult, output.Result) }) t.Run("ModifiedResultPropagated", func(t *testing.T) { modifiedResult := &schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: "modified result"}, }, } handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) EnhancedInvokableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { _, err := endpoint(ctx, toolArgument, opts...) 
if err != nil { return nil, err } return modifiedResult, nil } }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return &compose.EnhancedInvokableToolOutput{Result: &schema.ToolResult{ Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "original"}}, }}, nil } wrapped := middlewares[0].EnhancedInvokable(mockEndpoint) output, err := wrapped(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.NoError(t, err) assert.Equal(t, modifiedResult, output.Result) assert.Equal(t, "modified result", output.Result.Parts[0].Text) }) } func TestEnhancedToolEndpointErrorFromNext(t *testing.T) { t.Run("EnhancedInvokableNextError", func(t *testing.T) { expectedErr := errors.New("next endpoint error") handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedInvokableFn: func(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) EnhancedInvokableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { return endpoint(ctx, toolArgument, opts...) 
} }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { return nil, expectedErr } wrapped := middlewares[0].EnhancedInvokable(mockEndpoint) _, err := wrapped(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.Error(t, err) assert.Equal(t, expectedErr, err) }) t.Run("EnhancedStreamableNextError", func(t *testing.T) { expectedErr := errors.New("next endpoint error") handlers := []ChatModelAgentMiddleware{ &testEnhancedToolWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, wrapEnhancedStreamableFn: func(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) EnhancedStreamableToolCallEndpoint { return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { return endpoint(ctx, toolArgument, opts...) } }, }, } middlewares := handlersToToolMiddlewares(handlers) mockEndpoint := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { return nil, expectedErr } wrapped := middlewares[0].EnhancedStreamable(mockEndpoint) _, err := wrapped(context.Background(), &compose.ToolInput{Name: "test", CallID: "1", Arguments: "{}"}) assert.Error(t, err) assert.Equal(t, expectedErr, err) }) } func TestWrapModelStreamChunksPreserved(t *testing.T) { t.Run("AgentEventMessageStreamShouldPreserveChunksWithNoopWrapModel", func(t *testing.T) { ctx := context.Background() chunk1 := schema.AssistantMessage("Hello ", nil) chunk2 := schema.AssistantMessage("World", nil) mockModel := &mockStreamingModel{ chunks: []*schema.Message{chunk1, chunk2}, } noopWrapModelHandler := &testModelWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, fn: func(_ context.Context, m model.BaseChatModel, _ *ModelContext) model.BaseChatModel { return m }, } 
agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: mockModel, Handlers: []ChatModelAgentMiddleware{noopWrapModelHandler}, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, }, }) assert.NoError(t, err) r := NewRunner(ctx, RunnerConfig{ Agent: agent, EnableStreaming: true, }) iter := r.Run(ctx, []Message{schema.UserMessage("test")}) var streamingEvents []*AgentEvent for { event, ok := iter.Next() if !ok { break } if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming && event.Output.MessageOutput.Role == schema.Assistant { streamingEvents = append(streamingEvents, event) } } assert.GreaterOrEqual(t, len(streamingEvents), 1, "Should have at least one streaming event") if len(streamingEvents) > 0 { event := streamingEvents[0] assert.NotNil(t, event.Output.MessageOutput.MessageStream, "Event should have message stream") var receivedChunks []*schema.Message for { chunk, recvErr := event.Output.MessageOutput.MessageStream.Recv() if recvErr != nil { break } receivedChunks = append(receivedChunks, chunk) } assert.Equal(t, 2, len(receivedChunks), "AgentEvent's MessageStream should contain 2 separate chunks, not 1 concatenated chunk. "+ "Got %d chunks instead. 
This indicates the stream is being concatenated before being sent to AgentEvent.", len(receivedChunks)) if len(receivedChunks) >= 2 { assert.Equal(t, "Hello ", receivedChunks[0].Content, "First chunk content should be preserved") assert.Equal(t, "World", receivedChunks[1].Content, "Second chunk content should be preserved") } } }) t.Run("AgentEventMessageStreamShouldReflectUserMiddlewareModifications", func(t *testing.T) { ctx := context.Background() chunk1 := schema.AssistantMessage("Hello ", nil) chunk2 := schema.AssistantMessage("World", nil) mockModel := &mockStreamingModel{ chunks: []*schema.Message{chunk1, chunk2}, } streamConsumingWrapModelHandler := &testModelWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, fn: func(_ context.Context, m model.BaseChatModel, _ *ModelContext) model.BaseChatModel { return &streamConsumingModelWrapper{inner: m} }, } agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: mockModel, Handlers: []ChatModelAgentMiddleware{streamConsumingWrapModelHandler}, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, }, }) assert.NoError(t, err) r := NewRunner(ctx, RunnerConfig{ Agent: agent, EnableStreaming: true, }) iter := r.Run(ctx, []Message{schema.UserMessage("test")}) var streamingEvents []*AgentEvent for { event, ok := iter.Next() if !ok { break } if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming && event.Output.MessageOutput.Role == schema.Assistant { streamingEvents = append(streamingEvents, event) } } assert.GreaterOrEqual(t, len(streamingEvents), 1, "Should have at least one streaming event") if len(streamingEvents) > 0 { event := streamingEvents[0] assert.NotNil(t, event.Output.MessageOutput.MessageStream, "Event should have message stream") var receivedChunks []*schema.Message for { chunk, recvErr := event.Output.MessageOutput.MessageStream.Recv() if recvErr != nil { break } receivedChunks = 
append(receivedChunks, chunk) } assert.Equal(t, 1, len(receivedChunks), "AgentEvent's MessageStream should contain 1 concatenated chunk (modified by user middleware). "+ "Got %d chunks instead.", len(receivedChunks)) if len(receivedChunks) >= 1 { assert.Equal(t, "Hello World", receivedChunks[0].Content, "Chunk content should be concatenated by user middleware") } } }) t.Run("AgentEventMessageStreamShouldReflectMultipleUserMiddlewareModifications", func(t *testing.T) { ctx := context.Background() chunk1 := schema.AssistantMessage("Hello ", nil) chunk2 := schema.AssistantMessage("World", nil) mockModel := &mockStreamingModel{ chunks: []*schema.Message{chunk1, chunk2}, } handler1 := &testModelWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, fn: func(_ context.Context, m model.BaseChatModel, _ *ModelContext) model.BaseChatModel { return &streamConsumingModelWrapper{inner: m} }, } handler2 := &testModelWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, fn: func(_ context.Context, m model.BaseChatModel, _ *ModelContext) model.BaseChatModel { return m }, } agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: mockModel, Handlers: []ChatModelAgentMiddleware{handler1, handler2}, ModelRetryConfig: &ModelRetryConfig{ MaxRetries: 3, }, }) assert.NoError(t, err) r := NewRunner(ctx, RunnerConfig{ Agent: agent, EnableStreaming: true, }) iter := r.Run(ctx, []Message{schema.UserMessage("test")}) var streamingEvents []*AgentEvent for { event, ok := iter.Next() if !ok { break } if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming && event.Output.MessageOutput.Role == schema.Assistant { streamingEvents = append(streamingEvents, event) } } assert.GreaterOrEqual(t, len(streamingEvents), 1, "Should have at least one streaming event") if len(streamingEvents) > 0 { event := streamingEvents[0] assert.NotNil(t, 
event.Output.MessageOutput.MessageStream, "Event should have message stream") var receivedChunks []*schema.Message for { chunk, recvErr := event.Output.MessageOutput.MessageStream.Recv() if recvErr != nil { break } receivedChunks = append(receivedChunks, chunk) } assert.Equal(t, 1, len(receivedChunks), "AgentEvent's MessageStream should contain 1 concatenated chunk (modified by user middleware). "+ "Got %d chunks instead.", len(receivedChunks)) if len(receivedChunks) >= 1 { assert.Equal(t, "Hello World", receivedChunks[0].Content, "Chunk content should be concatenated by user middleware") } } }) } type mockStreamingModel struct { chunks []*schema.Message } func (m *mockStreamingModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { return schema.ConcatMessages(m.chunks) } func (m *mockStreamingModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { sr, sw := schema.Pipe[*schema.Message](len(m.chunks)) go func() { defer sw.Close() for _, chunk := range m.chunks { sw.Send(chunk, nil) } }() return sr, nil } type streamConsumingModelWrapper struct { inner model.BaseChatModel } func (m *streamConsumingModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { return m.inner.Generate(ctx, input, opts...) } func (m *streamConsumingModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { stream, err := m.inner.Stream(ctx, input, opts...) 
if err != nil { return nil, err } result, err := schema.ConcatMessageStream(stream) if err != nil { return nil, err } return schema.StreamReaderFromArray([]*schema.Message{result}), nil } func TestEventSenderModelWrapperCustomPosition(t *testing.T) { t.Run("UserConfiguredEventSenderSkipsDefaultEventSender", func(t *testing.T) { ctx := context.Background() chunk1 := schema.AssistantMessage("Hello ", nil) chunk2 := schema.AssistantMessage("World", nil) mockModel := &mockStreamingModel{ chunks: []*schema.Message{chunk1, chunk2}, } agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: mockModel, Handlers: []ChatModelAgentMiddleware{NewEventSenderModelWrapper()}, }) assert.NoError(t, err) r := NewRunner(ctx, RunnerConfig{ Agent: agent, EnableStreaming: true, }) iter := r.Run(ctx, []Message{schema.UserMessage("test")}) var streamingEvents []*AgentEvent for { event, ok := iter.Next() if !ok { break } if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming && event.Output.MessageOutput.Role == schema.Assistant { streamingEvents = append(streamingEvents, event) } } assert.Equal(t, 1, len(streamingEvents), "Should have exactly one streaming event (no duplicate from default event sender)") }) t.Run("EventSenderAfterUserMiddlewareByDefault", func(t *testing.T) { ctx := context.Background() mockModel := &mockStreamingModel{ chunks: []*schema.Message{ schema.AssistantMessage("Original", nil), }, } modifiedContent := "Modified" contentModifyingHandler := &testModelWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, fn: func(_ context.Context, m model.BaseChatModel, _ *ModelContext) model.BaseChatModel { return &contentModifyingModelWrapper{inner: m, newContent: modifiedContent} }, } agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: mockModel, Handlers: 
[]ChatModelAgentMiddleware{contentModifyingHandler}, }) assert.NoError(t, err) r := NewRunner(ctx, RunnerConfig{ Agent: agent, EnableStreaming: false, }) iter := r.Run(ctx, []Message{schema.UserMessage("test")}) var assistantEvents []*AgentEvent for { event, ok := iter.Next() if !ok { break } if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.Role == schema.Assistant { assistantEvents = append(assistantEvents, event) } } assert.GreaterOrEqual(t, len(assistantEvents), 1, "Should have at least one assistant event") if len(assistantEvents) > 0 { msg := assistantEvents[0].Output.MessageOutput.Message assert.Equal(t, modifiedContent, msg.Content, "Event should contain modified content from user middleware") } }) t.Run("EventSenderInnermostGetsOriginalOutput", func(t *testing.T) { ctx := context.Background() originalContent := "Original" mockModel := &mockStreamingModel{ chunks: []*schema.Message{ schema.AssistantMessage(originalContent, nil), }, } modifiedContent := "Modified" contentModifyingHandler := &testModelWrapperHandler{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, fn: func(_ context.Context, m model.BaseChatModel, _ *ModelContext) model.BaseChatModel { return &contentModifyingModelWrapper{inner: m, newContent: modifiedContent} }, } agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", Description: "Test agent", Model: mockModel, Handlers: []ChatModelAgentMiddleware{ contentModifyingHandler, NewEventSenderModelWrapper(), }, }) assert.NoError(t, err) r := NewRunner(ctx, RunnerConfig{ Agent: agent, EnableStreaming: false, }) iter := r.Run(ctx, []Message{schema.UserMessage("test")}) var assistantEvents []*AgentEvent for { event, ok := iter.Next() if !ok { break } if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.Role == schema.Assistant { assistantEvents = append(assistantEvents, event) } } assert.GreaterOrEqual(t, len(assistantEvents), 1, 
"Should have at least one assistant event") if len(assistantEvents) > 0 { msg := assistantEvents[0].Output.MessageOutput.Message assert.Equal(t, originalContent, msg.Content, "Event should contain original content (EventSenderModelWrapper is innermost)") } }) } type contentModifyingModelWrapper struct { inner model.BaseChatModel newContent string } func (m *contentModifyingModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { result, err := m.inner.Generate(ctx, input, opts...) if err != nil { return nil, err } result.Content = m.newContent return result, nil } func (m *contentModifyingModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { stream, err := m.inner.Stream(ctx, input, opts...) if err != nil { return nil, err } result, err := schema.ConcatMessageStream(stream) if err != nil { return nil, err } result.Content = m.newContent return schema.StreamReaderFromArray([]*schema.Message{result}), nil } ================================================ FILE: callbacks/aspect_inject.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package callbacks import ( "context" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/internal/callbacks" "github.com/cloudwego/eino/schema" ) // OnStart Fast inject callback input / output aspect for component developer // e.g. // // func (t *testChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (resp *schema.Message, err error) { // defer func() { // if err != nil { // callbacks.OnError(ctx, err) // } // }() // // ctx = callbacks.OnStart(ctx, &model.CallbackInput{ // Messages: input, // Tools: nil, // Extra: nil, // }) // // // do smt // // ctx = callbacks.OnEnd(ctx, &model.CallbackOutput{ // Message: resp, // Extra: nil, // }) // // return resp, nil // } // OnStart invokes the OnStart timing for all registered handlers in the // context. This is called by component implementations that manage their own // callbacks (i.e. implement [components.Checker] and return true from // IsCallbacksEnabled). The returned context must be propagated to subsequent // OnEnd/OnError calls so handlers can correlate start and end events. // // Handlers are invoked in reverse registration order (last registered = first // called) to match the middleware wrapping convention. // // Example — typical usage inside a component's Generate method: // // func (m *myChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { // ctx = callbacks.OnStart(ctx, &model.CallbackInput{Messages: input}) // resp, err := m.doGenerate(ctx, input, opts...) // if err != nil { // callbacks.OnError(ctx, err) // return nil, err // } // callbacks.OnEnd(ctx, &model.CallbackOutput{Message: resp}) // return resp, nil // } func OnStart[T any](ctx context.Context, input T) context.Context { ctx, _ = callbacks.On(ctx, input, callbacks.OnStartHandle[T], TimingOnStart, true) return ctx } // OnEnd invokes the OnEnd timing for all registered handlers. 
Call this after // the component produces a successful result. Handlers run in registration // order (first registered = first called). // // Do not call both OnEnd and OnError for the same invocation — OnEnd signals // success; OnError signals failure. func OnEnd[T any](ctx context.Context, output T) context.Context { ctx, _ = callbacks.On(ctx, output, callbacks.OnEndHandle[T], TimingOnEnd, false) return ctx } // OnStartWithStreamInput invokes the OnStartWithStreamInput timing. Use this // when the component's input is itself a stream (Collect / Transform // paradigms). The framework automatically copies the stream so each handler // receives an independent reader; handlers MUST close their copy or the // underlying goroutine will leak. // // Returns the updated context and a new StreamReader that the component should // use going forward (the original is consumed by the framework). func OnStartWithStreamInput[T any](ctx context.Context, input *schema.StreamReader[T]) ( nextCtx context.Context, newStreamReader *schema.StreamReader[T]) { return callbacks.On(ctx, input, callbacks.OnStartWithStreamInputHandle[T], TimingOnStartWithStreamInput, true) } // OnEndWithStreamOutput invokes the OnEndWithStreamOutput timing. Use this // when the component produces a streaming output (Stream / Transform // paradigms). Like OnStartWithStreamInput, stream copies are made per // handler; each handler must close its copy. // // Returns the updated context and the StreamReader the component should return // to its caller. func OnEndWithStreamOutput[T any](ctx context.Context, output *schema.StreamReader[T]) ( nextCtx context.Context, newStreamReader *schema.StreamReader[T]) { return callbacks.On(ctx, output, callbacks.OnEndWithStreamOutputHandle[T], TimingOnEndWithStreamOutput, false) } // OnError invokes the OnError timing for all registered handlers. Call this // when the component returns an error. 
Errors that occur mid-stream (after the // StreamReader has been returned) are NOT routed through OnError; they surface // as errors inside Recv. // // Handlers run in registration order (same as OnEnd). func OnError(ctx context.Context, err error) context.Context { ctx, _ = callbacks.On(ctx, err, callbacks.OnErrorHandle, TimingOnError, false) return ctx } // EnsureRunInfo ensures the context carries a [RunInfo] for the given type and // component kind. If the context already has a matching RunInfo, it is // returned unchanged. Otherwise, a new callback manager is created that // inherits the global handlers plus any handlers already in ctx. // // Component implementations that set IsCallbacksEnabled() = true should call // this at the start of every public method (Generate, Stream, etc.) before // calling [OnStart], so that the RunInfo is never missing from callbacks. func EnsureRunInfo(ctx context.Context, typ string, comp components.Component) context.Context { return callbacks.EnsureRunInfo(ctx, typ, comp) } // ReuseHandlers creates a new context that inherits all handlers already // present in ctx and sets a new RunInfo. Global handlers are added if ctx // carries none yet. // // Use this when a component calls another component internally and wants the // inner component's callbacks to share the same set of handlers as the outer // component, but with the inner component's own identity in RunInfo: // // innerCtx := callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{ // Type: "InnerChatModel", // Component: components.ComponentOfChatModel, // Name: "inner-cm", // }) func ReuseHandlers(ctx context.Context, info *RunInfo) context.Context { return callbacks.ReuseHandlers(ctx, info) } // InitCallbacks creates a new context with the given RunInfo and handlers, // completely replacing any RunInfo and handlers already in ctx. 
// // Use this when running a component standalone outside a Graph — the Graph // normally manages RunInfo injection automatically, but standalone callers must // set it up themselves: // // ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ // Type: myModel.GetType(), // Component: components.ComponentOfChatModel, // Name: "my-model", // }, myHandler) func InitCallbacks(ctx context.Context, info *RunInfo, handlers ...Handler) context.Context { return callbacks.InitCallbacks(ctx, info, handlers...) } ================================================ FILE: callbacks/aspect_inject_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package callbacks import ( "context" "fmt" "io" "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/internal/callbacks" "github.com/cloudwego/eino/schema" ) func TestAspectInject(t *testing.T) { t.Run("ctx without manager", func(t *testing.T) { ctx := context.Background() ctx = OnStart(ctx, 1) ctx = OnEnd(ctx, 2) ctx = OnError(ctx, fmt.Errorf("3")) isr, isw := schema.Pipe[int](2) go func() { for i := 0; i < 10; i++ { isw.Send(i, nil) } isw.Close() }() var nisr *schema.StreamReader[int] ctx, nisr = OnStartWithStreamInput(ctx, isr) j := 0 for { i, err := nisr.Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, j, i) j++ } nisr.Close() osr, osw := schema.Pipe[int](2) go func() { for i := 0; i < 10; i++ { osw.Send(i, nil) } osw.Close() }() var nosr *schema.StreamReader[int] ctx, nosr = OnEndWithStreamOutput(ctx, osr) j = 0 for { i, err := nosr.Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, j, i) j++ } nosr.Close() }) t.Run("ctx with manager", func(t *testing.T) { ctx := context.Background() cnt := 0 hb := NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { cnt += input.(int) return ctx }). OnEndFn(func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context { cnt += output.(int) return ctx }). OnErrorFn(func(ctx context.Context, info *RunInfo, err error) context.Context { v, _ := strconv.ParseInt(err.Error(), 10, 64) cnt += int(v) return ctx }). OnStartWithStreamInputFn(func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context { for { i, err := input.Recv() if err == io.EOF { break } cnt += i.(int) } input.Close() return ctx }). 
OnEndWithStreamOutputFn(func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context { for { o, err := output.Recv() if err == io.EOF { break } cnt += o.(int) } output.Close() return ctx }).Build() ctx = InitCallbacks(ctx, nil, hb) ctx = OnStart(ctx, 1) ctx = OnEnd(ctx, 2) ctx = OnError(ctx, fmt.Errorf("3")) isr, isw := schema.Pipe[int](2) go func() { for i := 0; i < 10; i++ { isw.Send(i, nil) } isw.Close() }() ctx = ReuseHandlers(ctx, &RunInfo{}) var nisr *schema.StreamReader[int] ctx, nisr = OnStartWithStreamInput(ctx, isr) j := 0 for { i, err := nisr.Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, j, i) j++ cnt += i } nisr.Close() osr, osw := schema.Pipe[int](2) go func() { for i := 0; i < 10; i++ { osw.Send(i, nil) } osw.Close() }() var nosr *schema.StreamReader[int] ctx, nosr = OnEndWithStreamOutput(ctx, osr) j = 0 for { i, err := nosr.Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, j, i) j++ cnt += i } nosr.Close() assert.Equal(t, 186, cnt) }) } func TestGlobalCallbacksRepeated(t *testing.T) { times := 0 testHandler := NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { times++ return ctx }).Build() callbacks.GlobalHandlers = append(callbacks.GlobalHandlers, testHandler) ctx := context.Background() ctx = callbacks.AppendHandlers(ctx, &RunInfo{}) ctx = callbacks.AppendHandlers(ctx, &RunInfo{}) callbacks.On(ctx, "test", callbacks.OnStartHandle[string], TimingOnStart, true) assert.Equal(t, times, 1) } func TestEnsureRunInfo(t *testing.T) { ctx := context.Background() var name, typ, comp string ctx = InitCallbacks(ctx, &RunInfo{Name: "name", Type: "type", Component: "component"}, NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { name = info.Name typ = info.Type comp = string(info.Component) return ctx }).Build()) ctx = OnStart(ctx, "") 
assert.Equal(t, "name", name) assert.Equal(t, "type", typ) assert.Equal(t, "component", comp) ctx2 := EnsureRunInfo(ctx, "type2", "component2") OnStart(ctx2, "") assert.Equal(t, "", name) assert.Equal(t, "type2", typ) assert.Equal(t, "component2", comp) // EnsureRunInfo on an empty Context AppendGlobalHandlers(NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { typ = info.Type comp = string(info.Component) return ctx }).Build()) ctx3 := EnsureRunInfo(context.Background(), "type3", "component3") OnStart(ctx3, 0) assert.Equal(t, "type3", typ) assert.Equal(t, "component3", comp) callbacks.GlobalHandlers = []Handler{} } func TestNesting(t *testing.T) { ctx := context.Background() cb := &myCallback{t: t} ctx = InitCallbacks(ctx, &RunInfo{ Name: "test", }, cb) // jumped ctx1 := OnStart(ctx, 0) ctx2 := OnStart(ctx1, 1) OnEnd(ctx2, 1) OnEnd(ctx1, 0) assert.Equal(t, 4, cb.times) // reused cb.times = 0 ctx1 = OnStart(ctx, 0) ctx2 = ReuseHandlers(ctx1, &RunInfo{Name: "test2"}) ctx3 := OnStart(ctx2, 1) OnEnd(ctx3, 1) OnEnd(ctx1, 0) assert.Equal(t, 4, cb.times) } func TestReuseHandlersOnEmptyCtx(t *testing.T) { callbacks.GlobalHandlers = []Handler{} cb := &myCallback{t: t} AppendGlobalHandlers(cb) ctx := ReuseHandlers(context.Background(), &RunInfo{Name: "test"}) OnStart(ctx, 0) assert.Equal(t, 1, cb.times) } func TestAppendHandlersTwiceOnSameCtx(t *testing.T) { callbacks.GlobalHandlers = []Handler{} cb := &myCallback{t: t} cb1 := &myCallback{t: t} cb2 := &myCallback{t: t} ctx := InitCallbacks(context.Background(), &RunInfo{Name: "test"}, cb) ctx1 := callbacks.AppendHandlers(ctx, &RunInfo{Name: "test"}, cb1) ctx2 := callbacks.AppendHandlers(ctx, &RunInfo{Name: "test"}, cb2) OnStart(ctx1, 0) OnStart(ctx2, 0) assert.Equal(t, 2, cb.times) assert.Equal(t, 1, cb1.times) assert.Equal(t, 1, cb2.times) } type myCallback struct { t *testing.T times int } func (m *myCallback) OnStart(ctx context.Context, info *RunInfo, input 
CallbackInput) context.Context { m.times++ if info == nil { assert.Equal(m.t, 2, m.times) return ctx } if info.Name == "test" { assert.Equal(m.t, 0, input) } else { assert.Equal(m.t, 1, input) } return ctx } func (m *myCallback) OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context { m.times++ if info == nil { assert.Equal(m.t, 3, m.times) return ctx } if info.Name == "test" { assert.Equal(m.t, 0, output) } else { assert.Equal(m.t, 1, output) } return ctx } func (m *myCallback) OnError(ctx context.Context, info *RunInfo, err error) context.Context { panic("implement me") } func (m *myCallback) OnStartWithStreamInput(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context { panic("implement me") } func (m *myCallback) OnEndWithStreamOutput(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context { panic("implement me") } ================================================ FILE: callbacks/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package callbacks provides observability hooks for component execution in Eino. // // Callbacks fire at five lifecycle timings around every component invocation: // - [TimingOnStart] / [TimingOnEnd]: non-streaming input and output. 
// - [TimingOnStartWithStreamInput] / [TimingOnEndWithStreamOutput]: streaming // variants — handlers receive a copy of the stream and MUST close it. // - [TimingOnError]: component returned a non-nil error (stream-internal // errors are NOT reported here). // // # Attaching Handlers // // Global handlers (observe every node in every graph): // // callbacks.AppendGlobalHandlers(myHandler) // call once, at startup — NOT thread-safe // // Per-invocation handlers (observe one graph run): // // runnable.Invoke(ctx, input, compose.WithCallbacks(myHandler)) // // Target a specific node: // // compose.WithCallbacks(myHandler).DesignateNode("nodeName") // // Handler inheritance: if the context passed to a graph run already carries // handlers (e.g. from a parent graph), those handlers are inherited by the // entire child run automatically. // // # Building Handlers // // Option 1 — [NewHandlerBuilder]: register raw functions for the timings you // need. Input/output are untyped; use the component package's ConvCallbackInput // helper to cast to a concrete type: // // handler := callbacks.NewHandlerBuilder(). // OnStartFn(func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { // // Handle component start // return ctx // }). // OnEndFn(func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context { // // Handle component end // return ctx // }). // OnErrorFn(func(ctx context.Context, info *RunInfo, err error) context.Context { // // Handle component error // return ctx // }). // OnStartWithStreamInputFn(func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context { // defer input.Close() // MUST close — failure causes pipeline goroutine leak // return ctx // }). // OnEndWithStreamOutputFn(func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context { // defer output.Close() // MUST close // return ctx // }). 
//		Build()
//
// Option 2 — utils/callbacks.NewHandlerHelper: dispatches by component type, so
// each handler function receives the concrete typed input/output directly:
//
//	handler := callbacks.NewHandlerHelper().
//		ChatModel(&model.CallbackHandler{
//			OnStart: func(ctx context.Context, info *RunInfo, input *model.CallbackInput) context.Context {
//				log.Printf("Model started: %s, messages: %d", info.Name, len(input.Messages))
//				return ctx
//			},
//		}).
//		Prompt(&prompt.CallbackHandler{
//			OnEnd: func(ctx context.Context, info *RunInfo, output *prompt.CallbackOutput) context.Context {
//				log.Printf("Prompt completed")
//				return ctx
//			},
//		}).
//		Handler()
//
// # Passing State Within a Handler
//
// The ctx returned by one timing is passed to the next timing of the SAME
// handler, enabling OnStart→OnEnd state transfer via context.WithValue. Read
// the value back with the comma-ok form: it is absent whenever this handler's
// OnStart did not run for the invocation.
//
//	NewHandlerBuilder().
//		OnStartFn(func(ctx context.Context, info *RunInfo, _ CallbackInput) context.Context {
//			return context.WithValue(ctx, startTimeKey{}, time.Now())
//		}).
//		OnEndFn(func(ctx context.Context, info *RunInfo, _ CallbackOutput) context.Context {
//			if start, ok := ctx.Value(startTimeKey{}).(time.Time); ok {
//				log.Printf("duration: %v", time.Since(start))
//			}
//			return ctx
//		}).Build()
//
// Between DIFFERENT handlers there is no guaranteed execution order and no
// context chain. To share state between handlers, store it in a
// concurrency-safe variable in the outermost context instead.
//
// # Common Pitfalls
//
//   - Stream copies must be closed: when N handlers register for a streaming
//     timing, the stream is copied N+1 times (one per handler + one for
//     downstream). If any handler's copy is not closed, the original stream
//     cannot be freed and the entire pipeline leaks.
//
//   - Do NOT mutate Input/Output: all downstream nodes and handlers share the
//     same pointer. Mutations cause data races in concurrent graph execution.
// // - AppendGlobalHandlers is NOT thread-safe: call only during initialization, // never concurrently with graph execution. // // - Stream errors are invisible to OnError: errors that occur while a // consumer reads from a StreamReader are not routed through OnError. // // - RunInfo may be nil: always nil-check before dereferencing in handlers, // especially when a component is used standalone outside a graph without // InitCallbacks being called. package callbacks ================================================ FILE: callbacks/handler_builder.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package callbacks import ( "context" "github.com/cloudwego/eino/schema" ) // HandlerBuilder constructs a [Handler] by registering callback functions for // individual timings. Only set the timings you care about; the built handler // implements [TimingChecker] and returns false for unregistered timings, so // the framework skips those timings with no overhead. // // The input/output values are untyped (CallbackInput / CallbackOutput). To // work with a specific component's payload, use the component package's // ConvCallbackInput / ConvCallbackOutput helpers inside your function. For a // higher-level API that dispatches by component type automatically, see // utils/callbacks.NewHandlerHelper. // // Example: // // handler := callbacks.NewHandlerBuilder(). 
// OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { // mi := model.ConvCallbackInput(input) // if mi != nil { // log.Printf("[%s] model start: %d messages", info.Name, len(mi.Messages)) // } // return ctx // }). // OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { // mo := model.ConvCallbackOutput(output) // if mo != nil && mo.Message.ResponseMeta != nil { // log.Printf("[%s] tokens: %d", info.Name, mo.Message.ResponseMeta.Usage.TotalTokens) // } // return ctx // }). // Build() type HandlerBuilder struct { onStartFn func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context onEndFn func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context onErrorFn func(ctx context.Context, info *RunInfo, err error) context.Context onStartWithStreamInputFn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context onEndWithStreamOutputFn func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context } type handlerImpl struct { HandlerBuilder } func (hb *handlerImpl) OnStart(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { return hb.onStartFn(ctx, info, input) } func (hb *handlerImpl) OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context { return hb.onEndFn(ctx, info, output) } func (hb *handlerImpl) OnError(ctx context.Context, info *RunInfo, err error) context.Context { return hb.onErrorFn(ctx, info, err) } func (hb *handlerImpl) OnStartWithStreamInput(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context { return hb.onStartWithStreamInputFn(ctx, info, input) } func (hb *handlerImpl) OnEndWithStreamOutput(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context { return hb.onEndWithStreamOutputFn(ctx, info, 
output) } func (hb *handlerImpl) Needed(_ context.Context, _ *RunInfo, timing CallbackTiming) bool { switch timing { case TimingOnStart: return hb.onStartFn != nil case TimingOnEnd: return hb.onEndFn != nil case TimingOnError: return hb.onErrorFn != nil case TimingOnStartWithStreamInput: return hb.onStartWithStreamInputFn != nil case TimingOnEndWithStreamOutput: return hb.onEndWithStreamOutputFn != nil default: return false } } // NewHandlerBuilder creates and returns a new HandlerBuilder instance. // HandlerBuilder is used to construct a Handler with custom callback functions func NewHandlerBuilder() *HandlerBuilder { return &HandlerBuilder{} } // OnStartFn sets the handler for the start timing. func (hb *HandlerBuilder) OnStartFn( fn func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context) *HandlerBuilder { hb.onStartFn = fn return hb } // OnEndFn sets the handler for the end timing. func (hb *HandlerBuilder) OnEndFn( fn func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context) *HandlerBuilder { hb.onEndFn = fn return hb } // OnErrorFn sets the handler for the error timing. func (hb *HandlerBuilder) OnErrorFn( fn func(ctx context.Context, info *RunInfo, err error) context.Context) *HandlerBuilder { hb.onErrorFn = fn return hb } // OnStartWithStreamInputFn sets the callback invoked when a component receives // streaming input. The handler receives a [*schema.StreamReader] that is a // private copy; it MUST close the reader after consuming it to avoid goroutine // and memory leaks. func (hb *HandlerBuilder) OnStartWithStreamInputFn( fn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context) *HandlerBuilder { hb.onStartWithStreamInputFn = fn return hb } // OnEndWithStreamOutputFn sets the callback invoked when a component produces // streaming output. 
Like OnStartWithStreamInputFn, the handler receives a // private copy of the stream and MUST close it after reading to prevent // goroutine and memory leaks. This is the right place to implement streaming // token-usage accounting or streaming log capture. func (hb *HandlerBuilder) OnEndWithStreamOutputFn( fn func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context) *HandlerBuilder { hb.onEndWithStreamOutputFn = fn return hb } // Build returns a Handler with the functions set in the builder. func (hb *HandlerBuilder) Build() Handler { return &handlerImpl{*hb} } ================================================ FILE: callbacks/interface.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package callbacks import ( "github.com/cloudwego/eino/internal/callbacks" ) // RunInfo describes the entity that triggered a callback. Always nil-check // before dereferencing — a component that calls OnStart without first calling // EnsureRunInfo or InitCallbacks will leave RunInfo absent in the context. // // Fields: // - Name: business-meaningful name specified by the user. For nodes in a // graph this is the node name (compose.WithNodeName). For standalone // components it must be set explicitly via [InitCallbacks] or // [ReuseHandlers]; it is empty string if not set. // - Type: implementation identity, e.g. "OpenAI". 
Set by the component via // [components.Typer]; falls back to reflection (struct/func name) if the // interface is not implemented. Empty for Graph itself. // - Component: category constant, e.g. components.ComponentOfChatModel. // Fixed value "Lambda" for lambdas, "Graph"/"Chain"/"Workflow" for graphs. // Use this to branch on component kind without caring about implementation. // // Handlers should filter using RunInfo rather than assuming a fixed execution // order — there is no guaranteed ordering between different Handlers. type RunInfo = callbacks.RunInfo // CallbackInput is the value passed to OnStart and OnStartWithStreamInput // handlers. The concrete type is defined by the component — for example, // ChatModel callbacks carry *model.CallbackInput. Use the component package's // ConvCallbackInput helper (e.g. model.ConvCallbackInput) to cast safely; it // returns nil if the type does not match, so you can ignore irrelevant // component types: // // modelInput := model.ConvCallbackInput(in) // if modelInput == nil { // return ctx // not a model invocation, skip // } // log.Printf("prompt: %v", modelInput.Messages) type CallbackInput = callbacks.CallbackInput // CallbackOutput is the value passed to OnEnd and OnEndWithStreamOutput // handlers. Like CallbackInput, the concrete type is component-defined. // Use the component package's ConvCallbackOutput helper to cast safely. type CallbackOutput = callbacks.CallbackOutput // Handler is the unified callback handler interface. Implement all five // methods (OnStart, OnEnd, OnError, OnStartWithStreamInput, // OnEndWithStreamOutput) or use [NewHandlerBuilder] to set only the timings // you care about. // // Each method receives the context returned by the previous timing of the // SAME handler, which lets a single handler pass state between its OnStart // and OnEnd calls via context.WithValue. 
There is NO guaranteed execution // order between DIFFERENT handlers, and the context chain does not flow // from one handler to the next — do not rely on handler ordering. // // Implement [TimingChecker] (the Needed method) on your handler so the // framework can skip timings you have not registered; this avoids unnecessary // stream copies and goroutine allocations on every component invocation. // // Stream handlers (OnStartWithStreamInput, OnEndWithStreamOutput) receive a // [*schema.StreamReader] that has already been copied; they MUST close their // copy after reading. If any handler's copy is not closed, the original stream // cannot be freed, causing a goroutine/memory leak for the entire pipeline. // // Important: do NOT mutate the Input or Output values. All downstream nodes // and handlers share the same pointer (direct assignment, not a deep copy). // Mutations cause data races in concurrent graph execution. type Handler = callbacks.Handler // InitCallbackHandlers sets the global callback handlers. // It should be called BEFORE any callback handler by user. // It's useful when you want to inject some basic callbacks to all nodes. // Deprecated: Use AppendGlobalHandlers instead. func InitCallbackHandlers(handlers []Handler) { callbacks.GlobalHandlers = handlers } // AppendGlobalHandlers appends handlers to the process-wide list of callback // handlers. Global handlers run before per-invocation handlers provided via // compose.WithCallbacks, giving them higher priority for instrumentation that // must observe every component invocation (e.g. distributed tracing, metrics). // // This function is NOT thread-safe. Call it once during program initialization // (e.g. in main or TestMain), before any graph executions begin. // Calling it concurrently with ongoing graph executions leads to data races. func AppendGlobalHandlers(handlers ...Handler) { callbacks.GlobalHandlers = append(callbacks.GlobalHandlers, handlers...) 
} // CallbackTiming enumerates the lifecycle moments at which a callback handler // is invoked. Implement [TimingChecker] on your handler and return false for // timings you do not handle, so the framework skips the overhead of stream // copying and goroutine spawning for those timings. type CallbackTiming = callbacks.CallbackTiming // Callback timing constants. const ( // TimingOnStart fires just before the component begins processing. // Receives a fully-formed input value (non-streaming). TimingOnStart CallbackTiming = iota // TimingOnEnd fires after the component returns a result successfully. // Receives the output value. Only fires on success — not on error. TimingOnEnd // TimingOnError fires when the component returns a non-nil error. // Stream errors (mid-stream panics) are NOT reported here; they surface // as errors inside the stream reader. TimingOnError // TimingOnStartWithStreamInput fires when the component receives a // streaming input (Collect / Transform paradigms). The handler receives a // copy of the input stream and must close it after reading. TimingOnStartWithStreamInput // TimingOnEndWithStreamOutput fires after the component returns a // streaming output (Stream / Transform paradigms). The handler receives a // copy of the output stream and must close it after reading. This is // typically where you implement streaming metrics or logging. TimingOnEndWithStreamOutput ) // TimingChecker is an optional interface for [Handler] implementations. // When a handler implements Needed, the framework calls it before each // component invocation to decide whether to set up callback infrastructure // (stream copying, goroutine allocation) for that timing. Returning false // avoids unnecessary overhead. // // Handlers built with [NewHandlerBuilder] or // utils/callbacks.NewHandlerHelper automatically implement TimingChecker // based on which callback functions were set. 
type TimingChecker = callbacks.TimingChecker ================================================ FILE: callbacks/interface_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package callbacks import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/internal/callbacks" ) func TestAppendGlobalHandlers(t *testing.T) { // Clear global handlers before test callbacks.GlobalHandlers = nil // Create test handlers handler1 := NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { return ctx }).Build() handler2 := NewHandlerBuilder(). OnEndFn(func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context { return ctx }).Build() // Test appending first handler AppendGlobalHandlers(handler1) assert.Equal(t, 1, len(callbacks.GlobalHandlers)) assert.Contains(t, callbacks.GlobalHandlers, handler1) // Test appending second handler AppendGlobalHandlers(handler2) assert.Equal(t, 2, len(callbacks.GlobalHandlers)) assert.Contains(t, callbacks.GlobalHandlers, handler1) assert.Contains(t, callbacks.GlobalHandlers, handler2) // Test appending nil handler AppendGlobalHandlers([]Handler{}...) 
assert.Equal(t, 2, len(callbacks.GlobalHandlers)) } ================================================ FILE: components/document/callback_extra_loader.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package document import ( "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) // LoaderCallbackInput is the input for the loader callback. type LoaderCallbackInput struct { // Source is the source of the documents. Source Source // Extra is the extra information for the callback. Extra map[string]any } // LoaderCallbackOutput is the output for the loader callback. type LoaderCallbackOutput struct { // Source is the source of the documents. Source Source // Docs is the documents to be loaded. Docs []*schema.Document // Extra is the extra information for the callback. Extra map[string]any } // ConvLoaderCallbackInput converts the callback input to the loader callback input. func ConvLoaderCallbackInput(src callbacks.CallbackInput) *LoaderCallbackInput { switch t := src.(type) { case *LoaderCallbackInput: return t case Source: return &LoaderCallbackInput{ Source: t, } default: return nil } } // ConvLoaderCallbackOutput converts the callback output to the loader callback output. 
func ConvLoaderCallbackOutput(src callbacks.CallbackOutput) *LoaderCallbackOutput { switch t := src.(type) { case *LoaderCallbackOutput: return t case []*schema.Document: return &LoaderCallbackOutput{ Docs: t, } default: return nil } } ================================================ FILE: components/document/callback_extra_transformer.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package document import ( "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) // TransformerCallbackInput is the input for the transformer callback. type TransformerCallbackInput struct { // Input is the input documents. Input []*schema.Document // Extra is the extra information for the callback. Extra map[string]any } // TransformerCallbackOutput is the output for the transformer callback. type TransformerCallbackOutput struct { // Output is the output documents. Output []*schema.Document // Extra is the extra information for the callback. Extra map[string]any } // ConvTransformerCallbackInput converts the callback input to the transformer callback input. 
func ConvTransformerCallbackInput(src callbacks.CallbackInput) *TransformerCallbackInput { switch t := src.(type) { case *TransformerCallbackInput: return t case []*schema.Document: return &TransformerCallbackInput{ Input: t, } default: return nil } } // ConvTransformerCallbackOutput converts the callback output to the transformer callback output. func ConvTransformerCallbackOutput(src callbacks.CallbackOutput) *TransformerCallbackOutput { switch t := src.(type) { case *TransformerCallbackOutput: return t case []*schema.Document: return &TransformerCallbackOutput{ Output: t, } default: return nil } } ================================================ FILE: components/document/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package document defines the Loader and Transformer component interfaces // for ingesting and processing documents in an eino pipeline. // // # Components // // - [Loader]: reads raw content from an external source (file, URL, S3, …) // and returns [schema.Document] values. Parsing is typically delegated to // a [parser.Parser] configured on the loader. // - [Transformer]: takes a slice of [schema.Document] values and transforms // them — splitting, filtering, merging, re-ranking, etc. 
//
// Concrete implementations live in eino-ext:
//
//	github.com/cloudwego/eino-ext/components/document/
//
// # Document Metadata
//
// [schema.Document].MetaData is the primary mechanism for carrying contextual
// information (source URI, scores, chunk indices, embeddings) through the
// pipeline. Transformers should preserve existing metadata and merge rather
// than replace when adding their own keys.
//
// See https://www.cloudwego.io/docs/eino/core_modules/components/document_loader_guide/
// See https://www.cloudwego.io/docs/eino/core_modules/components/document_transformer_guide/
package document

================================================
FILE: components/document/interface.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package document

import (
	"context"

	"github.com/cloudwego/eino/schema"
)

// Source identifies the external location of a document.
// URI can be a local file path or a remote URL reachable by the loader.
type Source struct {
	// URI locates the document; how it is interpreted is up to the Loader.
	URI string
}

//go:generate mockgen -destination ../../internal/mock/components/document/document_mock.go --package document -source interface.go

// Loader reads raw content from an external source and returns it as a slice
// of [schema.Document] values.
//
// The Source.URI may be a local file path or a remote URL. The loader is
// responsible for fetching the raw bytes; actual format parsing is typically
// delegated to a [parser.Parser] configured on the loader implementation.
// [WithParserOptions] attaches per-call [parser.Option] values to a Load
// request; whether and how they reach the parser is up to the Loader
// implementation.
//
// Document metadata ([schema.Document].MetaData) should be populated with at
// least the source URI so that downstream nodes can trace document provenance.
type Loader interface {
	Load(ctx context.Context, src Source, opts ...LoaderOption) ([]*schema.Document, error)
}

// Transformer converts a slice of [schema.Document] values into another slice,
// applying operations such as splitting, filtering, merging, or re-ranking.
//
// Implementations should preserve existing MetaData keys and merge rather than
// replace when adding their own metadata. Downstream nodes (e.g. Indexer,
// Retriever) may depend on metadata set by earlier pipeline stages.
type Transformer interface {
	Transform(ctx context.Context, src []*schema.Document, opts ...TransformerOption) ([]*schema.Document, error)
}

================================================
FILE: components/document/option.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package document

import "github.com/cloudwego/eino/components/document/parser"

// LoaderOptions configures document loaders, including parser options.
type LoaderOptions struct {
	// ParserOptions holds the parser options attached with WithParserOptions.
	ParserOptions []parser.Option
}

// LoaderOption defines call option for Loader component, which is part of the component interface signature.
// Each Loader implementation could define its own options struct and option funcs within its own package,
// then wrap the impl specific option funcs into this type, before passing to Load.
type LoaderOption struct {
	// apply mutates the common LoaderOptions; set by common option funcs such as WithParserOptions.
	apply func(opts *LoaderOptions)

	// implSpecificOptFn holds a func(*T) for an implementation-specific options
	// type T. It is stored untyped and recovered by GetLoaderImplSpecificOptions.
	implSpecificOptFn any
}

// WrapLoaderImplSpecificOptFn wraps the impl specific option functions into LoaderOption type.
// T: the type of the impl specific options struct.
// Loader implementations are required to use this function to convert its own option functions into the unified LoaderOption type.
// For example, if the Loader impl defines its own options struct:
//
//	type customOptions struct {
//		conf string
//	}
//
// Then the impl needs to provide an option function as such:
//
//	func WithConf(conf string) document.LoaderOption {
//		return document.WrapLoaderImplSpecificOptFn(func(o *customOptions) {
//			o.conf = conf
//		})
//	}
//
// The returned LoaderOption can then be passed to Load like any other option.
func WrapLoaderImplSpecificOptFn[T any](optFn func(*T)) LoaderOption {
	return LoaderOption{
		implSpecificOptFn: optFn,
	}
}

// GetLoaderImplSpecificOptions provides Loader author the ability to extract their own custom options from the unified LoaderOption type.
// T: the type of the impl specific options struct.
// This function should be used within the Loader implementation's Load function.
// It is recommended to provide a base T as the first argument, within which the Loader author can provide default values for the impl specific options.
// Options that do not wrap a func(*T) for this exact T are ignored.
// eg.
//
//	myOption := &MyOption{
//		Field1: "default_value",
//	}
//	myOption = document.GetLoaderImplSpecificOptions(myOption, opts...)
func GetLoaderImplSpecificOptions[T any](base *T, opts ...LoaderOption) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { s, ok := opt.implSpecificOptFn.(func(*T)) if ok { s(base) } } } return base } // GetLoaderCommonOptions extract loader Options from Option list, optionally providing a base Options with default values. func GetLoaderCommonOptions(base *LoaderOptions, opts ...LoaderOption) *LoaderOptions { if base == nil { base = &LoaderOptions{} } for i := range opts { opt := opts[i] if opt.apply != nil { opt.apply(base) } } return base } // WithParserOptions attaches parser options to a loader request. func WithParserOptions(opts ...parser.Option) LoaderOption { return LoaderOption{ apply: func(o *LoaderOptions) { o.ParserOptions = opts }, } } // TransformerOption defines call option for Transformer component, which is part of the component interface signature. // Each Transformer implementation could define its own options struct and option funcs within its own package, // then wrap the impl specific option funcs into this type, before passing to Transform. type TransformerOption struct { implSpecificOptFn any } // WrapTransformerImplSpecificOptFn wraps the impl specific option functions into TransformerOption type. // T: the type of the impl specific options struct. // Transformer implementations are required to use this function to convert its own option functions into the unified TransformerOption type. // For example, if the Transformer impl defines its own options struct: // // type customOptions struct { // conf string // } // // Then the impl needs to provide an option function as such: // // func WithConf(conf string) TransformerOption { // return WrapTransformerImplSpecificOptFn(func(o *customOptions) { // o.conf = conf // } // } // // . 
func WrapTransformerImplSpecificOptFn[T any](optFn func(*T)) TransformerOption { return TransformerOption{ implSpecificOptFn: optFn, } } // GetTransformerImplSpecificOptions provides Transformer author the ability to extract their own custom options from the unified TransformerOption type. // T: the type of the impl specific options struct. // This function should be used within the Transformer implementation's Transform function. // It is recommended to provide a base T as the first argument, within which the Transformer author can provide default values for the impl specific options. // eg. // // myOption := &MyOption{ // Field1: "default_value", // } // myOption := transformer.GetTransformerImplSpecificOptions(myOption, opts...) func GetTransformerImplSpecificOptions[T any](base *T, opts ...TransformerOption) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { s, ok := opt.implSpecificOptFn.(func(*T)) if ok { s(base) } } } return base } ================================================ FILE: components/document/option_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package document import ( "testing" "github.com/smartystreets/goconvey/convey" "github.com/cloudwego/eino/components/document/parser" ) func TestImplSpecificOpts(t *testing.T) { type implSpecificOptions struct { conf string index int } withConf := func(conf string) func(o *implSpecificOptions) { return func(o *implSpecificOptions) { o.conf = conf } } withIndex := func(index int) func(o *implSpecificOptions) { return func(o *implSpecificOptions) { o.index = index } } convey.Convey("TestLoaderImplSpecificOpts", t, func() { documentOption1 := WrapLoaderImplSpecificOptFn(withConf("test_conf")) documentOption2 := WrapLoaderImplSpecificOptFn(withIndex(1)) implSpecificOpts := GetLoaderImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) convey.So(implSpecificOpts, convey.ShouldResemble, &implSpecificOptions{ conf: "test_conf", index: 1, }) }) convey.Convey("TestTransformerImplSpecificOpts", t, func() { documentOption1 := WrapTransformerImplSpecificOptFn(withConf("test_conf")) documentOption2 := WrapTransformerImplSpecificOptFn(withIndex(1)) implSpecificOpts := GetTransformerImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) convey.So(implSpecificOpts, convey.ShouldResemble, &implSpecificOptions{ conf: "test_conf", index: 1, }) }) } func TestCommonOptions(t *testing.T) { convey.Convey("TestCommonOptions", t, func() { o := &LoaderOptions{ParserOptions: []parser.Option{{}}} o1 := GetLoaderCommonOptions(o) convey.So(len(o1.ParserOptions), convey.ShouldEqual, 1) o2 := GetLoaderCommonOptions(o, WithParserOptions(parser.Option{}, parser.Option{})) convey.So(len(o2.ParserOptions), convey.ShouldEqual, 2) }) } ================================================ FILE: components/document/parser/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package parser defines the Parser interface for converting raw byte streams // into [schema.Document] values. // // # Overview // // A Parser is not a standalone pipeline component — it is used inside a // [document.Loader] to handle format-specific decoding. The loader fetches // raw bytes; the parser converts them into documents. // // # Built-in Implementations // // - TextParser: treats the entire reader as plain text, one document per call // - ExtParser: selects a parser by file extension (from [Options.URI]), with // a configurable fallback for unknown extensions // // Use ExtParser when you want format-agnostic loading: pass the source URI // via [WithURI] and ExtParser picks the right sub-parser automatically. // // # Reader Contract // // The [io.Reader] passed to [Parser.Parse] is consumed during the call — // it cannot be read again. Loaders must not reuse the same reader across // multiple Parse calls. // // # Metadata Propagation // // Use [WithExtraMeta] to attach key-value pairs that are merged into every // document's MetaData. This is the standard way to tag documents with source // information (URI, content type, etc.) at parse time. 
// // See https://www.cloudwego.io/docs/eino/core_modules/components/document_loader_guide/document_parser_interface_guide/ package parser ================================================ FILE: components/document/parser/ext_parser.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package parser import ( "context" "errors" "io" "path/filepath" "github.com/cloudwego/eino/schema" ) // ExtParserConfig defines the configuration for the ExtParser. type ExtParserConfig struct { // ext -> parser. // eg: map[string]Parser{ // ".pdf": &PDFParser{}, // ".md": &MarkdownParser{}, // } Parsers map[string]Parser // Fallback parser to use when no other parser is found. // Default is TextParser if not set. FallbackParser Parser } // ExtParser is a parser that uses the file extension to determine which parser to use. // You can register your own parsers by calling RegisterParser. // Default parser is TextParser. // Note: // // parse 时,是通过 filepath.Ext(uri) 的方式找到对应的 parser,因此使用时需要: // ① 必须使用 parser.WithURI 在请求时传入 URI // ② URI 必须能通过 filepath.Ext 来解析出符合预期的 ext // // eg: // // pdf, _ := os.Open("./testdata/test.pdf") // docs, err := ExtParser.Parse(ctx, pdf, parser.WithURI("./testdata/test.pdf")) type ExtParser struct { parsers map[string]Parser fallbackParser Parser } // NewExtParser creates a new ExtParser. 
func NewExtParser(ctx context.Context, conf *ExtParserConfig) (*ExtParser, error) { if conf == nil { conf = &ExtParserConfig{} } p := &ExtParser{ parsers: conf.Parsers, fallbackParser: conf.FallbackParser, } if p.fallbackParser == nil { p.fallbackParser = TextParser{} } if p.parsers == nil { p.parsers = make(map[string]Parser) } return p, nil } // GetParsers returns a copy of the registered parsers. // It is safe to modify the returned parsers. func (p *ExtParser) GetParsers() map[string]Parser { res := make(map[string]Parser, len(p.parsers)) for k, v := range p.parsers { res[k] = v } return res } // Parse parses the given reader and returns a list of documents. func (p *ExtParser) Parse(ctx context.Context, reader io.Reader, opts ...Option) ([]*schema.Document, error) { opt := GetCommonOptions(&Options{}, opts...) ext := filepath.Ext(opt.URI) parser, ok := p.parsers[ext] if !ok { parser = p.fallbackParser } if parser == nil { return nil, errors.New("no parser found for extension " + ext) } docs, err := parser.Parse(ctx, reader, opts...) if err != nil { return nil, err } for _, doc := range docs { if doc == nil { continue } if doc.MetaData == nil { doc.MetaData = make(map[string]any) } for k, v := range opt.ExtraMeta { doc.MetaData[k] = v } } return docs, nil } ================================================ FILE: components/document/parser/interface.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package parser import ( "context" "io" "github.com/cloudwego/eino/schema" ) // Parser converts raw content from an [io.Reader] into [schema.Document] values. // // Parse may return multiple documents from a single reader (e.g. a PDF with // per-page splitting). The reader is consumed during Parse and must not be // reused. // // Parsers are typically not called directly — they are configured on a // [document.Loader] and invoked via [document.WithParserOptions]. type Parser interface { Parse(ctx context.Context, reader io.Reader, opts ...Option) ([]*schema.Document, error) } ================================================ FILE: components/document/parser/option.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package parser // Options configures the document parser with source URI and extra metadata. type Options struct { // uri of source. URI string // extra metadata will merge to each document. ExtraMeta map[string]any } // Option defines call option for Parser component, which is part of the component interface signature. // Each Parser implementation could define its own options struct and option funcs within its own package, // then wrap the impl specific option funcs into this type, before passing to Transform. 
type Option struct { apply func(opts *Options) implSpecificOptFn any } // WithURI specifies the source URI of the document. // It will be used as to select parser in ExtParser. func WithURI(uri string) Option { return Option{ apply: func(opts *Options) { opts.URI = uri }, } } // WithExtraMeta attaches extra metadata to the parsed document. func WithExtraMeta(meta map[string]any) Option { return Option{ apply: func(opts *Options) { opts.ExtraMeta = meta }, } } // GetCommonOptions extract parser Options from Option list, optionally providing a base Options with default values. func GetCommonOptions(base *Options, opts ...Option) *Options { if base == nil { base = &Options{} } for i := range opts { opt := opts[i] if opt.apply != nil { opt.apply(base) } } return base } // WrapImplSpecificOptFn wraps the impl specific option functions into Option type. // T: the type of the impl specific options struct. // Parser implementations are required to use this function to convert its own option functions into the unified Option type. // For example, if the Parser impl defines its own options struct: // // type customOptions struct { // conf string // } // // Then the impl needs to provide an option function as such: // // func WithConf(conf string) Option { // return WrapImplSpecificOptFn(func(o *customOptions) { // o.conf = conf // } // } // // . func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ implSpecificOptFn: optFn, } } // GetImplSpecificOptions provides Parser author the ability to extract their own custom options from the unified Option type. // T: the type of the impl specific options struct. // This function should be used within the Parser implementation's Transform function. // It is recommended to provide a base T as the first argument, within which the Parser author can provide default values for the impl specific options. 
func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { s, ok := opt.implSpecificOptFn.(func(*T)) if ok { s(base) } } } return base } ================================================ FILE: components/document/parser/option_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package parser import ( "testing" "github.com/smartystreets/goconvey/convey" ) func TestImplSpecificOpts(t *testing.T) { type implSpecificOptions struct { conf string index int } withConf := func(conf string) func(o *implSpecificOptions) { return func(o *implSpecificOptions) { o.conf = conf } } withIndex := func(index int) func(o *implSpecificOptions) { return func(o *implSpecificOptions) { o.index = index } } convey.Convey("TestImplSpecificOpts", t, func() { parserOption1 := WrapImplSpecificOptFn(withConf("test_conf")) parserOption2 := WrapImplSpecificOptFn(withIndex(1)) implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, parserOption1, parserOption2) convey.So(implSpecificOpts, convey.ShouldResemble, &implSpecificOptions{ conf: "test_conf", index: 1, }) }) } ================================================ FILE: components/document/parser/parser_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 
(the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package parser import ( "context" "io" "os" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) type ParserForTest struct { mock func() ([]*schema.Document, error) } func (p *ParserForTest) Parse(ctx context.Context, reader io.Reader, opts ...Option) ([]*schema.Document, error) { return p.mock() } func TestParser(t *testing.T) { ctx := context.Background() t.Run("Test default parser", func(t *testing.T) { conf := &ExtParserConfig{} p, err := NewExtParser(ctx, conf) if err != nil { t.Fatal(err) } f, err := os.Open("testdata/test.md") if err != nil { t.Fatal(err) } defer f.Close() docs, err := p.Parse(ctx, f, WithURI("testdata/test.md")) if err != nil { t.Fatal(err) } assert.Equal(t, 1, len(docs)) assert.Equal(t, "# Title\nhello world", docs[0].Content) }) t.Run("test types", func(t *testing.T) { mockParser := &ParserForTest{ mock: func() ([]*schema.Document, error) { return []*schema.Document{ { Content: "hello world", MetaData: map[string]any{ "type": "text", }, }, }, nil }, } conf := &ExtParserConfig{ Parsers: map[string]Parser{ ".md": mockParser, }, } p, err := NewExtParser(ctx, conf) if err != nil { t.Fatal(err) } f, err := os.Open("testdata/test.md") if err != nil { t.Fatal(err) } defer f.Close() docs, err := p.Parse(ctx, f, WithURI("x/test.md")) if err != nil { t.Fatal(err) } assert.Equal(t, 1, len(docs)) assert.Equal(t, "hello world", docs[0].Content) assert.Equal(t, "text", docs[0].MetaData["type"]) }) t.Run("test get parsers", func(t 
*testing.T) { p, err := NewExtParser(ctx, &ExtParserConfig{ Parsers: map[string]Parser{ ".md": &TextParser{}, }, }) if err != nil { t.Fatal(err) } ps := p.GetParsers() assert.Equal(t, 1, len(ps)) }) } ================================================ FILE: components/document/parser/testdata/test.md ================================================ # Title hello world ================================================ FILE: components/document/parser/text_parser.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package parser import ( "context" "io" "github.com/cloudwego/eino/schema" ) const ( // MetaKeySource is the metadata key storing the document's source URI. MetaKeySource = "_source" ) // TextParser is a simple parser that reads the text from a reader and returns a single document. // eg: // // docs, err := TextParser.Parse(ctx, strings.NewReader("hello world")) // fmt.Println(docs[0].Content) // "hello world" type TextParser struct{} // Parse reads the text from a reader and returns a single document. func (dp TextParser) Parse(ctx context.Context, reader io.Reader, opts ...Option) ([]*schema.Document, error) { data, err := io.ReadAll(reader) if err != nil { return nil, err } opt := GetCommonOptions(&Options{}, opts...) 
meta := make(map[string]any) meta[MetaKeySource] = opt.URI for k, v := range opt.ExtraMeta { meta[k] = v } doc := &schema.Document{ Content: string(data), MetaData: meta, } return []*schema.Document{doc}, nil } ================================================ FILE: components/embedding/callback_extra.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package embedding import ( "github.com/cloudwego/eino/callbacks" ) // TokenUsage is the token usage for the embedding. type TokenUsage struct { // PromptTokens is the number of prompt tokens. PromptTokens int // CompletionTokens is the number of completion tokens. CompletionTokens int // TotalTokens is the total number of tokens. TotalTokens int } // Config is the config for the embedding. type Config struct { // Model is the model name. Model string // EncodingFormat is the encoding format. EncodingFormat string } // ComponentExtra is the extra information for the embedding. type ComponentExtra struct { // Config is the config for the embedding. Config *Config // TokenUsage is the token usage for the embedding. TokenUsage *TokenUsage } // CallbackInput is the input for the embedding callback. type CallbackInput struct { // Texts is the texts to be embedded. Texts []string // Config is the config for the embedding. Config *Config // Extra is the extra information for the callback. 
Extra map[string]any } // CallbackOutput is the output for the embedding callback. type CallbackOutput struct { // Embeddings is the embeddings. Embeddings [][]float64 // Config is the config for creating the embedding. Config *Config // TokenUsage is the token usage for the embedding. TokenUsage *TokenUsage // Extra is the extra information for the callback. Extra map[string]any } // ConvCallbackInput converts the callback input to the embedding callback input. func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { switch t := src.(type) { case *CallbackInput: return t case []string: return &CallbackInput{ Texts: t, } default: return nil } } // ConvCallbackOutput converts the callback output to the embedding callback output. func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { switch t := src.(type) { case *CallbackOutput: return t case [][]float64: return &CallbackOutput{ Embeddings: t, } default: return nil } } ================================================ FILE: components/embedding/callback_extra_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package embedding import ( "testing" "github.com/stretchr/testify/assert" ) func TestConvEmbedding(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) assert.NotNil(t, ConvCallbackInput([]string{})) assert.Nil(t, ConvCallbackInput("asd")) assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) assert.NotNil(t, ConvCallbackOutput([][]float64{})) assert.Nil(t, ConvCallbackOutput("asd")) } ================================================ FILE: components/embedding/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package embedding defines the Embedder component interface for converting // text into vector representations. // // # Overview // // An Embedder converts a batch of strings into dense float vectors. Semantically // similar texts produce vectors that are close in the vector space, making // embeddings the backbone of semantic search, RAG pipelines, and clustering. // // Concrete implementations (OpenAI, Ark, Ollama, …) live in eino-ext: // // github.com/cloudwego/eino-ext/components/embedding/ // // # Output Format // // [Embedder.EmbedStrings] returns `[][]float64` where: // - outer index corresponds to the input text at the same position // - inner slice is the embedding vector; its length (dimensions) is fixed by // the model and is the same for every text // // # Consistency Requirement // // The same model must be used for both indexing and retrieval. 
Mixing models // produces vectors in different spaces — similarity scores become meaningless // and semantic search breaks silently. // // See https://www.cloudwego.io/docs/eino/core_modules/components/embedding_guide/ package embedding ================================================ FILE: components/embedding/interface.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package embedding import "context" // Embedder converts a batch of strings into dense vector representations. // // EmbedStrings returns one vector per input text, in the same order. The // vector length (dimensions) is fixed by the underlying model and identical // for every text in the batch. // // The returned [][]float64 maps as: // // embeddings[i] → vector for texts[i] // len(embeddings[i]) → model's embedding dimension (e.g. 1536 for ada-002) // // Both the Indexer and Retriever components use an Embedder to convert documents and // queries into vectors. They must share the exact same model — mismatched // dimensions or model families break semantic similarity.
// //go:generate mockgen -destination ../../internal/mock/components/embedding/Embedding_mock.go --package embedding -source interface.go type Embedder interface { EmbedStrings(ctx context.Context, texts []string, opts ...Option) ([][]float64, error) // invoke } ================================================ FILE: components/embedding/option.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package embedding // Options is the options for the embedding. type Options struct { // Model is the model name for the embedding. Model *string } // Option is a call-time option for an Embedder. type Option struct { apply func(opts *Options) implSpecificOptFn any } // WithModel is the option to set the model for the embedding. func WithModel(model string) Option { return Option{ apply: func(opts *Options) { opts.Model = &model }, } } // GetCommonOptions extract embedding Options from Option list, optionally providing a base Options with default values. // eg. // // defaultModelName := "default_model" // embeddingOption := &embedding.Options{ // Model: &defaultModelName, // } // embeddingOption := embedding.GetCommonOptions(embeddingOption, opts...) 
func GetCommonOptions(base *Options, opts ...Option) *Options { if base == nil { base = &Options{} } for i := range opts { opt := opts[i] if opt.apply != nil { opt.apply(base) } } return base } // WrapImplSpecificOptFn wraps an implementation-specific option function so it // can be passed alongside standard options. For use by Embedder implementors: // // func WithMyParam(v string) embedding.Option { // return embedding.WrapImplSpecificOptFn(func(o *MyOptions) { // o.MyParam = v // }) // } func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ implSpecificOptFn: optFn, } } // GetImplSpecificOptions extracts implementation-specific options from opts, // merging them onto base. Call alongside [GetCommonOptions] inside EmbedStrings: // // func (e *MyEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { // common := embedding.GetCommonOptions(nil, opts...) // mine := embedding.GetImplSpecificOptions(&MyOptions{}, opts...) // // use common.Model, mine.MyParam, etc. // } func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { optFn, ok := opt.implSpecificOptFn.(func(*T)) if ok { optFn(base) } } } return base } ================================================ FILE: components/embedding/option_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package embedding import ( "testing" "github.com/stretchr/testify/assert" ) func TestOptions(t *testing.T) { defaultModel := "default_model" opts := GetCommonOptions(&Options{Model: &defaultModel}, WithModel("test_model")) assert.NotNil(t, opts.Model) assert.Equal(t, *opts.Model, "test_model") } ================================================ FILE: components/indexer/callback_extra.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package indexer import ( "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) // CallbackInput is the input for the indexer callback. type CallbackInput struct { // Docs is the documents to be indexed. Docs []*schema.Document // Extra is the extra information for the callback. Extra map[string]any } // CallbackOutput is the output for the indexer callback. type CallbackOutput struct { // IDs is the ids of the indexed documents returned by the indexer. IDs []string // Extra is the extra information for the callback. Extra map[string]any } // ConvCallbackInput converts the callback input to the indexer callback input. 
func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { switch t := src.(type) { case *CallbackInput: return t case []*schema.Document: return &CallbackInput{ Docs: t, } default: return nil } } // ConvCallbackOutput converts the callback output to the indexer callback output. func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { switch t := src.(type) { case *CallbackOutput: return t case []string: return &CallbackOutput{ IDs: t, } default: return nil } } ================================================ FILE: components/indexer/callback_extra_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package indexer import ( "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestConvIndexer(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) assert.NotNil(t, ConvCallbackInput([]*schema.Document{})) assert.Nil(t, ConvCallbackInput("asd")) assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) assert.NotNil(t, ConvCallbackOutput([]string{})) assert.Nil(t, ConvCallbackOutput("asd")) } ================================================ FILE: components/indexer/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package indexer defines the Indexer component interface for storing documents // and their vector representations in a backend store. // // # Overview // // An Indexer is the write path of a RAG pipeline. It takes [schema.Document] // values, optionally generates vector embeddings, and persists them in a // backend (vector DB, search engine, etc.) for later retrieval. // // Concrete implementations (VikingDB, Milvus, Elasticsearch, …) live in // eino-ext: // // github.com/cloudwego/eino-ext/components/indexer/ // // # Vector Dimension Consistency // // When using the [Options.Embedding] option, the embedding model must be // identical to the one used by the paired [retriever.Retriever]. Mismatched // models produce vectors in different spaces — queries will not match stored // documents. // // See https://www.cloudwego.io/docs/eino/core_modules/components/indexer_guide/ package indexer ================================================ FILE: components/indexer/interface.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package indexer import ( "context" "github.com/cloudwego/eino/schema" ) // Indexer stores documents (and optionally their vector embeddings) in a // backend for later retrieval. // // Store accepts a batch of [schema.Document] values and returns the IDs // assigned to them by the backend. When [Options.Embedding] is provided, // the implementation generates vectors before storing — the same embedder // must be used by the paired [retriever.Retriever]. // // Use [Options.SubIndexes] to write documents into logical sub-partitions // within the same store. // //go:generate mockgen -destination ../../internal/mock/components/indexer/indexer_mock.go --package indexer -source interface.go type Indexer interface { // Store stores the documents and returns their assigned IDs. Store(ctx context.Context, docs []*schema.Document, opts ...Option) (ids []string, err error) // invoke } ================================================ FILE: components/indexer/option.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package indexer import "github.com/cloudwego/eino/components/embedding" // Options is the options for the indexer. type Options struct { // SubIndexes is the sub indexes to be indexed. SubIndexes []string // Embedding is the embedding component. 
Embedding embedding.Embedder } // WithSubIndexes is the option to set the sub indexes for the indexer. func WithSubIndexes(subIndexes []string) Option { return Option{ apply: func(opts *Options) { opts.SubIndexes = subIndexes }, } } // WithEmbedding is the option to set the embedder for the indexer, which convert document to embeddings. func WithEmbedding(emb embedding.Embedder) Option { return Option{ apply: func(opts *Options) { opts.Embedding = emb }, } } // Option is a call-time option for an Indexer. type Option struct { apply func(opts *Options) implSpecificOptFn any } // GetCommonOptions extracts standard [Options] from opts, merging onto base. // Implementors must call this inside Store: // // func (idx *MyIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { // options := indexer.GetCommonOptions(nil, opts...) // // use options.Embedding to generate vectors before storage // } func GetCommonOptions(base *Options, opts ...Option) *Options { if base == nil { base = &Options{} } for i := range opts { opt := opts[i] if opt.apply != nil { opt.apply(base) } } return base } // WrapImplSpecificOptFn wraps an implementation-specific option function so it // can be passed alongside standard options. For use by Indexer implementors. func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ implSpecificOptFn: optFn, } } // GetImplSpecificOptions extracts implementation-specific options from opts, // merging onto base. Call alongside [GetCommonOptions] inside Store. 
func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { optFn, ok := opt.implSpecificOptFn.(func(*T)) if ok { optFn(base) } } } return base } ================================================ FILE: components/indexer/option_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package indexer import ( "testing" "github.com/smartystreets/goconvey/convey" "github.com/cloudwego/eino/internal/mock/components/embedding" ) func TestOptions(t *testing.T) { convey.Convey("test options", t, func() { var ( subIndexes = []string{"index_1", "index_2"} e = &embedding.MockEmbedder{} ) opts := GetCommonOptions( &Options{}, WithSubIndexes(subIndexes), WithEmbedding(e), ) convey.So(opts, convey.ShouldResemble, &Options{ SubIndexes: subIndexes, Embedding: e, }) }) } ================================================ FILE: components/model/callback_extra.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package model import ( "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) // TokenUsage is the token usage for the model. type TokenUsage struct { // PromptTokens is the number of prompt tokens, including all the input tokens of this request. PromptTokens int // PromptTokenDetails is a breakdown of the prompt tokens. PromptTokenDetails PromptTokenDetails // CompletionTokens is the number of completion tokens. CompletionTokens int // TotalTokens is the total number of tokens. TotalTokens int // CompletionTokensDetails is breakdown of completion tokens. CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { // ReasoningTokens tokens generated by the model for reasoning. // This is currently supported by OpenAI, Gemini, ARK and Qwen chat models. // For other models, this field will be 0. ReasoningTokens int `json:"reasoning_tokens,omitempty"` } // PromptTokenDetails provides a breakdown of prompt token usage. type PromptTokenDetails struct { // Cached tokens present in the prompt. CachedTokens int } // Config is the config for the model. type Config struct { // Model is the model name. Model string // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". MaxTokens int // Temperature is the temperature, which controls the randomness of the model. Temperature float32 // TopP is the top p, which controls the diversity of the model. 
TopP float32 // Stop is the stop words, which controls the stopping condition of the model. Stop []string } // CallbackInput is the input for the model callback. type CallbackInput struct { // Messages is the messages to be sent to the model. Messages []*schema.Message // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // ToolChoice is the tool choice, which controls the tool to be used in the model. ToolChoice *schema.ToolChoice // Config is the config for the model. Config *Config // Extra is the extra information for the callback. Extra map[string]any } // CallbackOutput is the output for the model callback. type CallbackOutput struct { // Message is the message generated by the model. Message *schema.Message // Config is the config for the model. Config *Config // TokenUsage is the token usage of this request. TokenUsage *TokenUsage // Extra is the extra information for the callback. Extra map[string]any } // ConvCallbackInput converts the callback input to the model callback input. func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { switch t := src.(type) { case *CallbackInput: // when callback is triggered within component implementation, the input is usually already a typed *model.CallbackInput return t case []*schema.Message: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.Message return &CallbackInput{ Messages: t, } default: return nil } } // ConvCallbackOutput converts the callback output to the model callback output. 
func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { switch t := src.(type) { case *CallbackOutput: // when callback is triggered within component implementation, the output is usually already a typed *model.CallbackOutput return t case *schema.Message: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.Message return &CallbackOutput{ Message: t, } default: return nil } } ================================================ FILE: components/model/callback_extra_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package model import ( "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestConvModel(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) assert.NotNil(t, ConvCallbackInput([]*schema.Message{})) assert.Nil(t, ConvCallbackInput("asd")) assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) assert.NotNil(t, ConvCallbackOutput(&schema.Message{})) assert.Nil(t, ConvCallbackOutput("asd")) } ================================================ FILE: components/model/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package model defines the ChatModel component interface for interacting with // large language models (LLMs). // // # Overview // // A ChatModel takes a slice of [schema.Message] as input and returns a response // message — either in full ([BaseChatModel.Generate]) or incrementally as a // stream ([BaseChatModel.Stream]). It is the most fundamental building block in // an eino pipeline: every application that talks to an LLM goes through this // interface. // // Concrete implementations (OpenAI, Ark, Ollama, …) live in eino-ext: // // github.com/cloudwego/eino-ext/components/model/ // // # Interface Hierarchy // // BaseChatModel — Generate + Stream (all implementations) // ├── ToolCallingChatModel — preferred; WithTools returns a new instance (concurrency-safe) // └── ChatModel — deprecated; BindTools mutates state (avoid in new code) // // # Choosing Generate vs Stream // // Use [BaseChatModel.Generate] when the full response is needed before // proceeding (e.g. structured extraction, classification). // Use [BaseChatModel.Stream] when output should be forwarded to the caller // incrementally (e.g. chat UI, long-form generation). Always close the // [schema.StreamReader] returned by Stream — failing to do so leaks the // underlying connection: // // reader, err := model.Stream(ctx, messages) // if err != nil { ... } // defer reader.Close() // // # Implementing a ChatModel // // Implementations must call [GetCommonOptions] to extract standard options and // [GetImplSpecificOptions] to extract their own options from the Option list. 
// Expose implementation-specific options via [WrapImplSpecificOptFn]. // // See https://www.cloudwego.io/docs/eino/core_modules/components/chat_model_guide/ // for the full component guide. package model ================================================ FILE: components/model/interface.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package model import ( "context" "github.com/cloudwego/eino/schema" ) // BaseChatModel defines the core interface for all chat model implementations. // // It exposes two modes of interaction: // - [BaseChatModel.Generate]: blocks until the model returns a complete response. // - [BaseChatModel.Stream]: returns a [schema.StreamReader] that yields message // chunks incrementally as the model generates them. // // The input is a slice of [schema.Message] representing the conversation history. // Messages carry a role (system, user, assistant, tool) and support multimodal // content (text, images, audio, video). Tool messages must include a ToolCallID // that correlates them with a prior assistant tool-call message. // // Stream usage — the caller is responsible for closing the reader: // // reader, err := m.Stream(ctx, messages) // if err != nil { ... } // defer reader.Close() // for { // chunk, err := reader.Recv() // if errors.Is(err, io.EOF) { break } // if err != nil { ... } // // handle chunk // } // // Note: a [schema.StreamReader] can only be read once. 
If multiple consumers // need the stream, it must be copied before reading. // //go:generate mockgen -destination ../../internal/mock/components/model/ChatModel_mock.go --package model -source interface.go type BaseChatModel interface { Generate(ctx context.Context, input []*schema.Message, opts ...Option) (*schema.Message, error) Stream(ctx context.Context, input []*schema.Message, opts ...Option) ( *schema.StreamReader[*schema.Message], error) } // Deprecated: Use [ToolCallingChatModel] instead. // // ChatModel extends [BaseChatModel] with tool binding via [ChatModel.BindTools]. // BindTools mutates the instance in place, which causes a race condition when // the same instance is used concurrently: one goroutine's tool list can // overwrite another's. Prefer [ToolCallingChatModel.WithTools], which returns // a new immutable instance and is safe for concurrent use. type ChatModel interface { BaseChatModel // BindTools bind tools to the model. // BindTools before requesting ChatModel generally. // notice the non-atomic problem of BindTools and Generate. BindTools(tools []*schema.ToolInfo) error } // ToolCallingChatModel extends [BaseChatModel] with safe tool binding. // // Unlike the deprecated [ChatModel.BindTools], [ToolCallingChatModel.WithTools] // does not mutate the receiver — it returns a new instance with the given tools // attached. This makes it safe to share a base model instance across goroutines // and derive per-request variants with different tool sets: // // base, _ := openai.NewChatModel(ctx, cfg) // shared, no tools // withSearch, _ := base.WithTools([]*schema.ToolInfo{searchTool}) // withCalc, _ := base.WithTools([]*schema.ToolInfo{calcTool}) type ToolCallingChatModel interface { BaseChatModel // WithTools returns a new ToolCallingChatModel instance with the specified tools bound. // This method does not modify the current instance, making it safer for concurrent use. 
WithTools(tools []*schema.ToolInfo) (ToolCallingChatModel, error) } ================================================ FILE: components/model/option.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package model import "github.com/cloudwego/eino/schema" // Options is the common options for the model. type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. Temperature *float32 // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". MaxTokens *int // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. TopP *float32 // Stop is the stop words for the model, which controls the stopping condition of the model. Stop []string // Tools is a list of tools the model may call. Tools []*schema.ToolInfo // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string } // Option is a call-time option for a ChatModel. 
Options are immutable and // composable: each Option carries either a common-option setter (applied via // [GetCommonOptions]) or an implementation-specific setter (applied via // [GetImplSpecificOptions]), never both. type Option struct { apply func(opts *Options) implSpecificOptFn any } // WithTemperature is the option to set the temperature for the model. func WithTemperature(temperature float32) Option { return Option{ apply: func(opts *Options) { opts.Temperature = &temperature }, } } // WithMaxTokens is the option to set the max tokens for the model. func WithMaxTokens(maxTokens int) Option { return Option{ apply: func(opts *Options) { opts.MaxTokens = &maxTokens }, } } // WithModel is the option to set the model name. func WithModel(name string) Option { return Option{ apply: func(opts *Options) { opts.Model = &name }, } } // WithTopP is the option to set the top p for the model. func WithTopP(topP float32) Option { return Option{ apply: func(opts *Options) { opts.TopP = &topP }, } } // WithStop is the option to set the stop words for the model. func WithStop(stop []string) Option { return Option{ apply: func(opts *Options) { opts.Stop = stop }, } } // WithTools is the option to set tools for the model. func WithTools(tools []*schema.ToolInfo) Option { if tools == nil { tools = []*schema.ToolInfo{} } return Option{ apply: func(opts *Options) { opts.Tools = tools }, } } // WithToolChoice sets the tool choice for the model. It also allows for providing a list of // tool names to constrain the model to a specific subset of the available tools. func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Option { return Option{ apply: func(opts *Options) { opts.ToolChoice = &toolChoice opts.AllowedToolNames = allowedToolNames }, } } // WrapImplSpecificOptFn wraps an implementation-specific option function into // an [Option] so it can be passed alongside standard options. // // This is intended for ChatModel implementors, not callers. 
Define a typed // setter for your own config struct and expose it as an Option: // // // In your implementation package: // func WithMyParam(v string) model.Option { // return model.WrapImplSpecificOptFn(func(o *MyOptions) { // o.MyParam = v // }) // } // // Callers can then mix standard and implementation-specific options freely: // // model.Generate(ctx, msgs, // model.WithTemperature(0.7), // mypkg.WithMyParam("value"), // ) func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ implSpecificOptFn: optFn, } } // GetCommonOptions extracts standard [Options] from an Option list, merging // them onto base. If base is nil, a zero-value Options is used. // // Implementors must call this to honour options passed by callers: // // func (m *MyModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { // options := model.GetCommonOptions(&model.Options{Temperature: &m.defaultTemp}, opts...) // // use options.Temperature, options.Tools, etc. // } func GetCommonOptions(base *Options, opts ...Option) *Options { if base == nil { base = &Options{} } for i := range opts { opt := opts[i] if opt.apply != nil { opt.apply(base) } } return base } // GetImplSpecificOptions extracts implementation-specific options from an // Option list, merging them onto base. If base is nil, a zero-value T is used. // // Call this alongside [GetCommonOptions] to support both standard and custom // options in your implementation: // // type MyOptions struct { MyParam string } // // func (m *MyModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { // common := model.GetCommonOptions(nil, opts...) // myOpts := model.GetImplSpecificOptions(&MyOptions{MyParam: "default"}, opts...) // // use common.Temperature, myOpts.MyParam, etc. 
//	}
//
// Options wrapped for a different type parameter are skipped, as are the
// standard options built by the With* helpers.
func GetImplSpecificOptions[T any](base *T, opts ...Option) *T {
	if base == nil {
		base = new(T)
	}
	for i := range opts {
		opt := opts[i]
		if opt.implSpecificOptFn != nil {
			// The assertion fails for functions wrapped with a different T,
			// so such options are silently ignored.
			optFn, ok := opt.implSpecificOptFn.(func(*T))
			if ok {
				optFn(base)
			}
		}
	}
	return base
}

================================================
FILE: components/model/option_test.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package model

import (
	"testing"

	"github.com/smartystreets/goconvey/convey"

	"github.com/cloudwego/eino/schema"
)

// TestOptions checks that GetCommonOptions lets every common option override
// the matching field of the base Options, and that WithTools(nil) replaces
// existing tools with an empty, non-nil slice.
func TestOptions(t *testing.T) {
	convey.Convey("test options", t, func() {
		var (
			modelName                  = "model"
			temperature        float32 = 0.9
			maxToken                   = 5000
			topP               float32 = 0.8
			defaultModel               = "default_model"
			defaultTemperature float32 = 1.0
			defaultMaxTokens           = 1000
			defaultTopP        float32 = 0.5
			tools                      = []*schema.ToolInfo{{Name: "asd"}, {Name: "qwe"}}
			toolChoice                 = schema.ToolChoiceForced
			allowedToolNames           = []string{"web_search"}
		)
		// Every base field set here is expected to be overridden below.
		opts := GetCommonOptions(
			&Options{
				Model:       &defaultModel,
				Temperature: &defaultTemperature,
				MaxTokens:   &defaultMaxTokens,
				TopP:        &defaultTopP,
			},
			WithModel(modelName),
			WithTemperature(temperature),
			WithMaxTokens(maxToken),
			WithTopP(topP),
			WithStop([]string{"hello", "bye"}),
			WithTools(tools),
			WithToolChoice(toolChoice, allowedToolNames...),
		)
		convey.So(opts, convey.ShouldResemble, &Options{
			Model:            &modelName,
			Temperature:      &temperature,
			MaxTokens:        &maxToken,
			TopP:             &topP,
			Stop:             []string{"hello", "bye"},
			Tools:            tools,
			ToolChoice:       &toolChoice,
			AllowedToolNames: allowedToolNames,
		})
	})

	convey.Convey("test nil tools option", t, func() {
		// WithTools(nil) must clear the base tools rather than leave them in place.
		opts := GetCommonOptions(
			&Options{
				Tools: []*schema.ToolInfo{
					{Name: "asd"},
					{Name: "qwe"},
				},
			},
			WithTools(nil),
		)
		convey.So(opts.Tools, convey.ShouldNotBeNil)
		convey.So(len(opts.Tools), convey.ShouldEqual, 0)
	})
}

// implOption is a sample implementation-specific options struct used to
// exercise WrapImplSpecificOptFn and GetImplSpecificOptions.
type implOption struct {
	userID int64
	name   string
}

// WithUserID returns an Option that sets implOption.userID.
func WithUserID(uid int64) Option {
	return WrapImplSpecificOptFn[implOption](func(i *implOption) {
		i.userID = uid
	})
}

// WithName returns an Option that sets implOption.name.
func WithName(n string) Option {
	return WrapImplSpecificOptFn[implOption](func(i *implOption) {
		i.name = n
	})
}

// TestImplSpecificOption checks that both wrapped options are applied to the
// base struct.
func TestImplSpecificOption(t *testing.T) {
	convey.Convey("impl_specific_option", t, func() {
		opt := GetImplSpecificOptions(&implOption{}, WithUserID(101), WithName("Wang"))
		// NOTE(review): compares two distinct pointers, so this relies on
		// convey.ShouldEqual doing a deep comparison — confirm for the pinned
		// goconvey/assertions version.
		convey.So(opt, convey.ShouldEqual, &implOption{
			userID: 101,
			name:   "Wang",
		})
	})
}

================================================
FILE: components/prompt/callback_extra.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package prompt

import (
	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/schema"
)

// CallbackInput is the input for the callback.
type CallbackInput struct {
	// Variables are the values being substituted into the templates.
	Variables map[string]any
	// Templates are the message templates being formatted.
	Templates []schema.MessagesTemplate
	// Extra is the extra information for the callback.
	Extra map[string]any
}

// CallbackOutput is the output for the callback.
type CallbackOutput struct {
	// Result holds the formatted messages.
	Result []*schema.Message
	// Templates are the message templates that produced Result.
	Templates []schema.MessagesTemplate
	// Extra is the extra information for the callback.
	Extra map[string]any
}

// ConvCallbackInput converts the callback input to the prompt callback input.
// A *CallbackInput is returned unchanged, a bare map[string]any is wrapped as
// Variables, and any other type yields nil.
func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput {
	switch t := src.(type) {
	case *CallbackInput:
		return t
	case map[string]any:
		return &CallbackInput{
			Variables: t,
		}
	default:
		return nil
	}
}

// ConvCallbackOutput converts the callback output to the prompt callback output.
// A *CallbackOutput is returned unchanged, a bare []*schema.Message is wrapped
// as Result, and any other type yields nil.
func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput {
	switch t := src.(type) {
	case *CallbackOutput:
		return t
	case []*schema.Message:
		return &CallbackOutput{
			Result: t,
		}
	default:
		return nil
	}
}

================================================
FILE: components/prompt/callback_extra_test.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package prompt import ( "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestConvPrompt(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) assert.Nil(t, ConvCallbackOutput("asd")) } ================================================ FILE: components/prompt/chat_template.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package prompt import ( "context" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/schema" ) // DefaultChatTemplate is the default chat template implementation. type DefaultChatTemplate struct { // templates is the templates for the chat template. templates []schema.MessagesTemplate // formatType is the format type for the chat template. formatType schema.FormatType } // FromMessages creates a new DefaultChatTemplate from the given templates and format type. // eg. 
// // template := prompt.FromMessages(schema.FString, &schema.Message{Content: "Hello, {name}!"}, &schema.Message{Content: "how are you?"}) // // in chain, or graph // chain := compose.NewChain[map[string]any, []*schema.Message]() // chain.AppendChatTemplate(template) func FromMessages(formatType schema.FormatType, templates ...schema.MessagesTemplate) *DefaultChatTemplate { return &DefaultChatTemplate{ templates: templates, formatType: formatType, } } // Format formats the chat template with the given context and variables. func (t *DefaultChatTemplate) Format(ctx context.Context, vs map[string]any, _ ...Option) (result []*schema.Message, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfPrompt) ctx = callbacks.OnStart(ctx, &CallbackInput{ Variables: vs, Templates: t.templates, }) defer func() { if err != nil { _ = callbacks.OnError(ctx, err) } }() result = make([]*schema.Message, 0, len(t.templates)) for _, template := range t.templates { msgs, err := template.Format(ctx, vs, t.formatType) if err != nil { return nil, err } result = append(result, msgs...) } _ = callbacks.OnEnd(ctx, &CallbackOutput{ Result: result, Templates: t.templates, }) return result, nil } // GetType returns the type of the chat template (Default). func (t *DefaultChatTemplate) GetType() string { return "Default" } // IsCallbacksEnabled checks if the callbacks are enabled for the chat template. func (t *DefaultChatTemplate) IsCallbacksEnabled() bool { return true } ================================================ FILE: components/prompt/chat_template_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package prompt

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/cloudwego/eino/schema"
)

// TestFormat renders the same prompt written in FString, Jinja2 and
// GoTemplate syntax and expects identical messages from all three, with the
// chat_history placeholder expanded in place.
func TestFormat(t *testing.T) {
	pyFmtTestTemplate := []schema.MessagesTemplate{
		schema.SystemMessage(
			"you are a helpful assistant.\n" +
				"here is the context: {context}"),
		schema.MessagesPlaceholder("chat_history", true),
		schema.UserMessage("question: {question}"),
	}
	jinja2TestTemplate := []schema.MessagesTemplate{
		schema.SystemMessage(
			"you are a helpful assistant.\n" +
				"here is the context: {{context}}"),
		schema.MessagesPlaceholder("chat_history", true),
		schema.UserMessage("question: {{question}}"),
	}
	goFmtTestTemplate := []schema.MessagesTemplate{
		schema.SystemMessage(
			"you are a helpful assistant.\n" +
				"here is the context: {{.context}}"),
		schema.MessagesPlaceholder("chat_history", true),
		schema.UserMessage("question: {{.question}}"),
	}
	testValues := map[string]any{
		"context":  "it's beautiful day",
		"question": "how is the day today",
		"chat_history": []*schema.Message{
			schema.UserMessage("who are you"),
			schema.AssistantMessage("I'm a helpful assistant", nil),
		},
	}

	// The history messages appear between the system and the user message.
	expected := []*schema.Message{
		schema.SystemMessage(
			"you are a helpful assistant.\n" +
				"here is the context: it's beautiful day"),
		schema.UserMessage("who are you"),
		schema.AssistantMessage("I'm a helpful assistant", nil),
		schema.UserMessage("question: how is the day today"),
	}

	// FString
	chatTemplate := FromMessages(schema.FString, pyFmtTestTemplate...)
	msgs, err := chatTemplate.Format(context.Background(), testValues)
	assert.Nil(t, err)
	assert.Equal(t, expected, msgs)

	// Jinja2
	chatTemplate = FromMessages(schema.Jinja2, jinja2TestTemplate...)
	msgs, err = chatTemplate.Format(context.Background(), testValues)
	assert.Nil(t, err)
	assert.Equal(t, expected, msgs)

	// GoTemplate
	chatTemplate = FromMessages(schema.GoTemplate, goFmtTestTemplate...)
	msgs, err = chatTemplate.Format(context.Background(), testValues)
	assert.Nil(t, err)
	assert.Equal(t, expected, msgs)
}

// TestDocumentFormat only checks that a document slice and a single document
// can be substituted without error. The rendered messages are logged, not compared.
func TestDocumentFormat(t *testing.T) {
	docs := []*schema.Document{
		{
			ID:      "1",
			Content: "qwe",
			MetaData: map[string]any{
				"hello": 888,
			},
		},
		{
			ID:      "2",
			Content: "asd",
			MetaData: map[string]any{
				"bye": 111,
			},
		},
	}

	template := FromMessages(schema.FString,
		schema.SystemMessage("all:{all_docs}\nsingle:{single_doc}"),
	)

	msgs, err := template.Format(context.Background(), map[string]any{
		"all_docs":   docs,
		"single_doc": docs[0],
	})

	assert.Nil(t, err)
	t.Log(msgs)
}

// TestMultiContentFormat checks that FString substitution reaches Content,
// the text part, and the URL of every media part in MultiContent.
func TestMultiContentFormat(t *testing.T) {
	mtpl := []schema.MessagesTemplate{
		&schema.Message{
			Content: "{a}",
			MultiContent: []schema.ChatMessagePart{
				{
					Type: schema.ChatMessagePartTypeText,
					Text: "{b}",
				},
				{
					Type: schema.ChatMessagePartTypeImageURL,
					ImageURL: &schema.ChatMessageImageURL{
						URL: "{c}",
					},
				},
				{
					Type: schema.ChatMessagePartTypeAudioURL,
					AudioURL: &schema.ChatMessageAudioURL{
						URL: "{d}",
					},
				},
				{
					Type: schema.ChatMessagePartTypeVideoURL,
					VideoURL: &schema.ChatMessageVideoURL{
						URL: "{e}",
					},
				},
				{
					Type: schema.ChatMessagePartTypeFileURL,
					FileURL: &schema.ChatMessageFileURL{
						URL: "{f}",
					},
				},
			},
		},
	}
	input := map[string]any{
		"a": "content",
		"b": "text",
		"c": "image url",
		"d": "audio url",
		"e": "video url",
		"f": "file url",
	}
	expected := []*schema.Message{
		{
			Content: "content",
			MultiContent: []schema.ChatMessagePart{
				{
					Type: schema.ChatMessagePartTypeText,
					Text: "text",
				},
				{
					Type: schema.ChatMessagePartTypeImageURL,
					ImageURL: &schema.ChatMessageImageURL{
						URL: "image url",
					},
				},
				{
					Type: schema.ChatMessagePartTypeAudioURL,
					AudioURL: &schema.ChatMessageAudioURL{
						URL: "audio url",
					},
				},
				{
					Type: schema.ChatMessagePartTypeVideoURL,
					VideoURL: &schema.ChatMessageVideoURL{
						URL: "video url",
					},
				},
				{
					Type: schema.ChatMessagePartTypeFileURL,
					FileURL: &schema.ChatMessageFileURL{
						URL: "file url",
					},
				},
			},
		},
	}

	tpl := FromMessages(schema.FString, mtpl...)
	result, err := tpl.Format(context.Background(), input)
	assert.Nil(t, err)
	assert.Equal(t, expected, result)
}

================================================
FILE: components/prompt/doc.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Package prompt defines the ChatTemplate component interface for building
// structured message lists from templates and runtime variables.
//
// # Overview
//
// A ChatTemplate takes a variables map and produces a []*schema.Message slice
// ready to pass to a [model.BaseChatModel]. It is typically the first node in
// a pipeline, sitting before the ChatModel.
// // The built-in [DefaultChatTemplate] supports three template syntaxes: // - FString: {variable} substitution // - GoTemplate: Go's text/template with conditionals and loops // - Jinja2: Jinja2 template syntax // // # Construction // // Use [FromMessages] to build a template from a list of message templates: // // tmpl := prompt.FromMessages(schema.FString, // schema.SystemMessage("You are a helpful assistant."), // schema.UserMessage("Answer this: {question}"), // ) // msgs, err := tmpl.Format(ctx, map[string]any{"question": "What is eino?"}) // // Use [schema.MessagesPlaceholder] to insert a dynamic list of messages // (e.g. conversation history) at a fixed position in the template: // // tmpl := prompt.FromMessages(schema.FString, // schema.SystemMessage("You are a helpful assistant."), // schema.MessagesPlaceholder("history", true), // schema.UserMessage("{question}"), // ) // // # Common Pitfall // // Variable mismatches (a key present in the template but missing from the // variables map) produce a runtime error — there is no compile-time check. // // See https://www.cloudwego.io/docs/eino/core_modules/components/chat_template_guide/ package prompt ================================================ FILE: components/prompt/interface.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package prompt import ( "context" "github.com/cloudwego/eino/schema" ) var _ ChatTemplate = &DefaultChatTemplate{} // ChatTemplate formats a variables map into a list of messages for a ChatModel. // // Format substitutes the values from vs into the template's message list and // returns the resulting []*schema.Message. The exact substitution syntax // (FString, GoTemplate, Jinja2) is determined at construction time. // // Variable keys present in the template but absent from vs produce a runtime // error — there is no compile-time safety. Prefer consistent variable naming // across templates and callers. // // In a Graph or Chain, ChatTemplate typically precedes ChatModel. Use // compose.WithOutputKey to convert the prior node's output into the map[string]any // that Format expects. // // See [FromMessages] and [schema.MessagesPlaceholder] for construction helpers. type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } ================================================ FILE: components/prompt/option.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package prompt // Option is a call-time option for a ChatTemplate. The built-in // [DefaultChatTemplate] has no common options — this type exists primarily for // custom ChatTemplate implementations that need per-call configuration. 
type Option struct { implSpecificOptFn any } // WrapImplSpecificOptFn wraps an implementation-specific option function so it // can be passed alongside any future standard options. For use by custom // ChatTemplate implementors. func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ implSpecificOptFn: optFn, } } // GetImplSpecificOptions extracts the implementation specific options from Option list, optionally providing a base options with default values. func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { s, ok := opt.implSpecificOptFn.(func(*T)) if ok { s(base) } } } return base } ================================================ FILE: components/prompt/option_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

package prompt

import (
	"testing"

	"github.com/smartystreets/goconvey/convey"
)

// implOption is a sample implementation-specific options struct used to
// exercise WrapImplSpecificOptFn and GetImplSpecificOptions.
type implOption struct {
	userID int64
	name   string
}

// WithUserID returns an Option that sets implOption.userID.
func WithUserID(uid int64) Option {
	return WrapImplSpecificOptFn[implOption](func(i *implOption) {
		i.userID = uid
	})
}

// WithName returns an Option that sets implOption.name.
func WithName(n string) Option {
	return WrapImplSpecificOptFn[implOption](func(i *implOption) {
		i.name = n
	})
}

// TestImplSpecificOption checks that both wrapped options are applied to the
// base struct.
func TestImplSpecificOption(t *testing.T) {
	convey.Convey("impl_specific_option", t, func() {
		opt := GetImplSpecificOptions(&implOption{}, WithUserID(101), WithName("Wang"))
		convey.So(opt, convey.ShouldEqual, &implOption{
			userID: 101,
			name:   "Wang",
		})
	})
}

================================================
FILE: components/retriever/callback_extra.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package retriever

import (
	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/schema"
)

// CallbackInput is the input for the retriever callback.
type CallbackInput struct {
	// Query is the query for the retriever.
	Query string
	// TopK is the top k for the retriever, which means the top number of documents to retrieve.
	TopK int
	// Filter is the filter for the retriever.
	Filter string
	// ScoreThreshold is the score threshold for the retriever, e.g. 0.5 means the score of the document must be greater than 0.5.
	ScoreThreshold *float64
	// Extra is the extra information for the retriever.
	Extra map[string]any
}

// CallbackOutput is the output for the retriever callback.
type CallbackOutput struct {
	// Docs are the documents returned by the retriever.
	Docs []*schema.Document
	// Extra is the extra information for the retriever.
	Extra map[string]any
}

// ConvCallbackInput converts the callback input to the retriever callback input.
// A *CallbackInput is returned unchanged, a bare string is treated as the
// query, and any other type yields nil.
func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput {
	switch t := src.(type) {
	case *CallbackInput:
		return t
	case string:
		return &CallbackInput{
			Query: t,
		}
	default:
		return nil
	}
}

// ConvCallbackOutput converts the callback output to the retriever callback output.
// A *CallbackOutput is returned unchanged, a bare []*schema.Document is
// wrapped as Docs, and any other type yields nil.
func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput {
	switch t := src.(type) {
	case *CallbackOutput:
		return t
	case []*schema.Document:
		return &CallbackOutput{
			Docs: t,
		}
	default:
		return nil
	}
}

================================================
FILE: components/retriever/callback_extra_test.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package retriever import ( "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestConvRetriever(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) assert.NotNil(t, ConvCallbackInput("asd")) assert.Nil(t, ConvCallbackInput([]string{})) assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) assert.NotNil(t, ConvCallbackOutput([]*schema.Document{})) assert.Nil(t, ConvCallbackOutput("asd")) } ================================================ FILE: components/retriever/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package retriever defines the Retriever component interface for fetching // relevant documents from a document store given a query. // // # Overview // // A Retriever is the read path of a RAG (Retrieval-Augmented Generation) // pipeline. Given a query string it returns the most relevant [schema.Document] // values from an underlying store (vector DB, keyword index, etc.). 
// // Concrete implementations (VikingDB, Milvus, Elasticsearch, …) live in // eino-ext: // // github.com/cloudwego/eino-ext/components/retriever/ // // # Relationship to Indexer // // [Indexer] and Retriever are complementary: // - Indexer writes documents (and their vectors) to the store // - Retriever reads them back // // When both use an [embedding.Embedder], it must be the same model — vector // dimensions must match or similarity scores will be meaningless. // // # Result Ordering // // Results are ordered by relevance score (descending). Scores and other // backend metadata are available via [schema.Document].MetaData. // // See https://www.cloudwego.io/docs/eino/core_modules/components/retriever_guide/ package retriever ================================================ FILE: components/retriever/interface.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package retriever import ( "context" "github.com/cloudwego/eino/schema" ) //go:generate mockgen -destination ../../internal/mock/components/retriever/retriever_mock.go --package retriever -source interface.go // Retriever fetches the most relevant documents from a store for a given query. // // Retrieve accepts a natural-language query string and returns matching // [schema.Document] values ordered by relevance (most relevant first). // Relevance scores and backend-specific metadata are available in // [schema.Document].MetaData. 
// // When [Options.Embedding] is set, the implementation converts the query to a // vector before searching. The embedder must be the same model used at index // time — see [indexer.Options.Embedding]. // // [Options.ScoreThreshold] is a filter, not a sort: documents scoring below // the threshold are excluded entirely. [Options.TopK] caps the number of // results returned. // // Retrieve can be used standalone or added to a Graph via AddRetrieverNode: // // retriever, _ := redis.NewRetriever(ctx, cfg) // docs, _ := retriever.Retrieve(ctx, "what is eino?", retriever.WithTopK(5)) // // graph.AddRetrieverNode("retriever", retriever) type Retriever interface { Retrieve(ctx context.Context, query string, opts ...Option) ([]*schema.Document, error) } ================================================ FILE: components/retriever/option.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package retriever import "github.com/cloudwego/eino/components/embedding" // Options is the options for the retriever. type Options struct { // Index is the index for the retriever, index in different retriever may be different. Index *string // SubIndex is the sub index for the retriever, sub index in different retriever may be different. SubIndex *string // TopK is the top k for the retriever, which means the top number of documents to retrieve. 
TopK *int // ScoreThreshold is the score threshold for the retriever, eg 0.5 means the score of the document must be greater than 0.5. ScoreThreshold *float64 // Embedding is the embedder for the retriever, which is used to embed the query for retrieval . Embedding embedding.Embedder // DSLInfo carries backend-specific filter/query expressions. The structure and // semantics are defined by the underlying store implementation. DSLInfo map[string]any } // WithIndex wraps the index option. func WithIndex(index string) Option { return Option{ apply: func(opts *Options) { opts.Index = &index }, } } // WithSubIndex wraps the sub index option. func WithSubIndex(subIndex string) Option { return Option{ apply: func(opts *Options) { opts.SubIndex = &subIndex }, } } // WithTopK wraps the top k option. func WithTopK(topK int) Option { return Option{ apply: func(opts *Options) { opts.TopK = &topK }, } } // WithScoreThreshold wraps the score threshold option. func WithScoreThreshold(threshold float64) Option { return Option{ apply: func(opts *Options) { opts.ScoreThreshold = &threshold }, } } // WithEmbedding wraps the embedder option. func WithEmbedding(emb embedding.Embedder) Option { return Option{ apply: func(opts *Options) { opts.Embedding = emb }, } } // WithDSLInfo wraps the dsl info option. func WithDSLInfo(dsl map[string]any) Option { return Option{ apply: func(opts *Options) { opts.DSLInfo = dsl }, } } // Option is a call-time option for a Retriever. type Option struct { apply func(opts *Options) implSpecificOptFn any } // GetCommonOptions extracts standard [Options] from opts, merging onto base. // Implementors must call this to honour caller-provided options: // // func (r *MyRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { // options := retriever.GetCommonOptions(&retriever.Options{TopK: &r.defaultTopK}, opts...) // // use options.TopK, options.ScoreThreshold, options.Embedding, etc. 
// } func GetCommonOptions(base *Options, opts ...Option) *Options { if base == nil { base = &Options{} } for i := range opts { if opts[i].apply != nil { opts[i].apply(base) } } return base } // WrapImplSpecificOptFn wraps an implementation-specific option function so it // can be passed alongside standard options. For use by Retriever implementors. func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ implSpecificOptFn: optFn, } } // GetImplSpecificOptions extracts implementation-specific options from opts, // merging onto base. Call alongside [GetCommonOptions] inside Retrieve. func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { optFn, ok := opt.implSpecificOptFn.(func(*T)) if ok { optFn(base) } } } return base } ================================================ FILE: components/retriever/option_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package retriever import ( "testing" "github.com/smartystreets/goconvey/convey" "github.com/cloudwego/eino/internal/mock/components/embedding" ) func TestOptions(t *testing.T) { convey.Convey("test options", t, func() { var ( index = "index" topK = 2 scoreThreshold = 4.0 subIndex = "sub_index" dslInfo = map[string]any{"dsl": "dsl"} e = &embedding.MockEmbedder{} defaultTopK = 1 ) opts := GetCommonOptions( &Options{ TopK: &defaultTopK, }, WithIndex(index), WithTopK(topK), WithScoreThreshold(scoreThreshold), WithSubIndex(subIndex), WithDSLInfo(dslInfo), WithEmbedding(e), ) convey.So(opts, convey.ShouldResemble, &Options{ Index: &index, TopK: &topK, ScoreThreshold: &scoreThreshold, SubIndex: &subIndex, DSLInfo: dslInfo, Embedding: e, }) }) } ================================================ FILE: components/tool/callback_extra.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package tool import ( "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) // CallbackInput is the input for the tool callback. type CallbackInput struct { // ArgumentsInJSON is the arguments in json format for the tool. ArgumentsInJSON string // Extra is the extra information for the tool. Extra map[string]any } // CallbackOutput is the output for the tool callback. type CallbackOutput struct { // Response is the response for the tool. 
Response string // ToolOutput is the multimodal output for the tool. Used when the tool returns structured data. ToolOutput *schema.ToolResult // Extra is the extra information for the tool. Extra map[string]any } // ConvCallbackInput converts the callback input to the tool callback input. func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { switch t := src.(type) { case *CallbackInput: return t case string: return &CallbackInput{ArgumentsInJSON: t} case *schema.ToolArgument: return &CallbackInput{ArgumentsInJSON: t.Text} default: return nil } } // ConvCallbackOutput converts the callback output to the tool callback output. func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { switch t := src.(type) { case *CallbackOutput: return t case string: return &CallbackOutput{Response: t} case *schema.ToolResult: return &CallbackOutput{ToolOutput: t} default: return nil } } ================================================ FILE: components/tool/callback_extra_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package tool import ( "testing" "github.com/stretchr/testify/assert" ) func TestConvCallbackInput(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) assert.NotNil(t, ConvCallbackInput("asd")) assert.Nil(t, ConvCallbackInput(123)) assert.Nil(t, ConvCallbackInput(nil)) } func TestConvCallbackOutput(t *testing.T) { assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) assert.NotNil(t, ConvCallbackOutput("asd")) assert.Nil(t, ConvCallbackOutput(123)) assert.Nil(t, ConvCallbackOutput(nil)) } ================================================ FILE: components/tool/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package tool defines the tool component interfaces that allow language models // to invoke external capabilities, and helpers for interrupt/resume within tools. 
//
// # Interface Hierarchy
//
//	BaseTool                   — Info() only; for passing tool metadata to a ChatModel
//	├── InvokableTool          — standard: args as JSON string, returns string
//	├── StreamableTool         — standard streaming: args as JSON string, returns StreamReader[string]
//	├── EnhancedInvokableTool  — multimodal: args as *schema.ToolArgument, returns *schema.ToolResult
//	└── EnhancedStreamableTool — multimodal streaming: args as *schema.ToolArgument, returns StreamReader[*schema.ToolResult]
//
// # Choosing an Interface
//
// Implement [InvokableTool] for most tools — the model's tool-call arguments
// arrive as a JSON-encoded string, and the string result is sent back to the
// model. The [utils] constructors can decode that JSON into a typed Go value
// for you.
//
// Implement [EnhancedInvokableTool] when the tool needs to return structured
// multimodal content (images, audio, files) rather than plain text. When a
// tool implements both a standard and an enhanced interface, ToolsNode
// prioritises the enhanced interface.
//
// # Creating Tools
//
// The [utils] sub-package provides constructors that eliminate boilerplate:
//   - [utils.InferTool] / [utils.InferStreamTool] — infer parameter schema from Go struct tags
//   - [utils.NewTool] / [utils.NewStreamTool] — manual ToolInfo + typed function
//
// # Interrupt / Resume
//
// Tools can pause execution and wait for external input using [Interrupt],
// [StatefulInterrupt], and [CompositeInterrupt]. Use [GetInterruptState] and
// [GetResumeContext] inside the tool to distinguish first-run from resumed-run.
//
// See https://www.cloudwego.io/docs/eino/core_modules/components/tools_node_guide/
// See https://www.cloudwego.io/docs/eino/core_modules/components/tools_node_guide/how_to_create_a_tool/
package tool

================================================
FILE: components/tool/interface.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package tool

import (
	"context"

	"github.com/cloudwego/eino/schema"
)

// BaseTool provides the metadata that a ChatModel uses to decide whether and
// how to call a tool. Info returns a [schema.ToolInfo] containing the tool
// name, description, and parameter JSON schema.
//
// BaseTool alone is sufficient when passing tool definitions to a ChatModel
// via WithTools — the model only needs the schema to generate tool calls.
// To also execute the tool, implement [InvokableTool] or [StreamableTool].
type BaseTool interface {
	// Info returns the tool's metadata: name, description and parameter JSON schema.
	Info(ctx context.Context) (*schema.ToolInfo, error)
}

// InvokableTool is a tool that can be executed by ToolsNode.
//
// InvokableRun receives the model's tool call arguments as a JSON-encoded
// string and returns a plain string result that is sent back to the model as
// a tool message. The framework handles JSON decoding automatically when using
// the [utils.InferTool] or [utils.NewTool] constructors.
type InvokableTool interface {
	BaseTool

	// InvokableRun executes the tool with arguments encoded as a JSON string.
	InvokableRun(ctx context.Context, argumentsInJSON string, opts ...Option) (string, error)
}

// StreamableTool is a streaming variant of [InvokableTool].
//
// StreamableRun returns a [schema.StreamReader] that yields string chunks
// incrementally. The caller (ToolsNode) is responsible for closing the reader.
type StreamableTool interface {
	BaseTool

	// StreamableRun executes the tool with arguments encoded as a JSON string
	// and returns its output as a stream of string chunks.
	StreamableRun(ctx context.Context, argumentsInJSON string, opts ...Option) (*schema.StreamReader[string], error)
}

// EnhancedInvokableTool is a tool that returns structured multimodal results.
//
// Unlike [InvokableTool], arguments arrive as a [schema.ToolArgument] (not a
// raw JSON string) and the result is a [schema.ToolResult] which can carry
// text, images, audio, video, and file content.
//
// When a tool implements both a standard and an enhanced interface, ToolsNode
// prioritises the enhanced interface.
type EnhancedInvokableTool interface {
	BaseTool

	// InvokableRun executes the tool with structured arguments and returns a
	// single, possibly multimodal, result.
	InvokableRun(ctx context.Context, toolArgument *schema.ToolArgument, opts ...Option) (*schema.ToolResult, error)
}

// EnhancedStreamableTool is the streaming variant of [EnhancedInvokableTool].
//
// It streams [schema.ToolResult] chunks, enabling incremental multimodal
// output. The caller is responsible for closing the returned [schema.StreamReader].
type EnhancedStreamableTool interface {
	BaseTool

	// StreamableRun executes the tool with structured arguments and returns its
	// output as a stream of [schema.ToolResult] chunks.
	StreamableRun(ctx context.Context, toolArgument *schema.ToolArgument, opts ...Option) (*schema.StreamReader[*schema.ToolResult], error)
}

================================================
FILE: components/tool/interrupt.go
================================================
/*
 * Copyright 2026 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package tool import ( "context" "errors" "fmt" "github.com/cloudwego/eino/internal/core" ) // Interrupt pauses tool execution and signals the orchestration layer to checkpoint. // The tool can be resumed later with optional data. // // Parameters: // - ctx: The context passed to InvokableRun/StreamableRun // - info: User-facing information about why the tool is interrupting (e.g., "needs user confirmation") // // Returns an error that should be returned from InvokableRun/StreamableRun. // // Example: // // func (t *MyTool) InvokableRun(ctx context.Context, args string, opts ...Option) (string, error) { // if needsConfirmation(args) { // return "", tool.Interrupt(ctx, "Please confirm this action") // } // return doWork(args), nil // } func Interrupt(ctx context.Context, info any) error { is, err := core.Interrupt(ctx, info, nil, nil) if err != nil { return err } return is } // StatefulInterrupt pauses tool execution with state preservation. // Use this when the tool has internal state that must be restored on resume. // // Parameters: // - ctx: The context passed to InvokableRun/StreamableRun // - info: User-facing information about the interrupt // - state: Internal state to persist (must be gob-serializable) // // Example: // // func (t *MyTool) InvokableRun(ctx context.Context, args string, opts ...Option) (string, error) { // wasInterrupted, hasState, state := tool.GetInterruptState[MyState](ctx) // if !wasInterrupted { // // First run - interrupt with state // return "", tool.StatefulInterrupt(ctx, "processing", MyState{Step: 1}) // } // // Resumed - continue from saved state // return continueFrom(state), nil // } func StatefulInterrupt(ctx context.Context, info any, state any) error { is, err := core.Interrupt(ctx, info, state, nil) if err != nil { return err } return is } // CompositeInterrupt creates an interrupt that aggregates multiple sub-interrupts. // Use this when a tool internally executes a graph or other interruptible components. 
// // Parameters: // - ctx: The context passed to InvokableRun/StreamableRun // - info: User-facing information for this tool's interrupt // - state: Internal state to persist for this tool // - errs: Interrupt errors from sub-components (graphs, other tools, etc.) // // Example: // // func (t *MyTool) InvokableRun(ctx context.Context, args string, opts ...Option) (string, error) { // result, err := t.internalGraph.Invoke(ctx, input) // if err != nil { // if _, ok := tool.IsInterruptError(err); ok { // return "", tool.CompositeInterrupt(ctx, "graph interrupted", myState, err) // } // return "", err // } // return result, nil // } func CompositeInterrupt(ctx context.Context, info any, state any, errs ...error) error { if len(errs) == 0 { return StatefulInterrupt(ctx, info, state) } var cErrs []*core.InterruptSignal for _, err := range errs { ire := &core.InterruptSignal{} if errors.As(err, &ire) { cErrs = append(cErrs, ire) continue } var provider core.InterruptContextsProvider if errors.As(err, &provider) { is := core.FromInterruptContexts(provider.GetInterruptContexts()) if is != nil { cErrs = append(cErrs, is) } continue } return fmt.Errorf("composite interrupt but one of the sub error is not interrupt error: %w", err) } is, err := core.Interrupt(ctx, info, state, cErrs) if err != nil { return err } return is } // GetInterruptState checks if the tool was previously interrupted and retrieves saved state. 
// // Returns: // - wasInterrupted: true if this tool was part of a previous interruption // - hasState: true if state was saved and successfully cast to type T // - state: the saved state (zero value if hasState is false) // // Example: // // func (t *MyTool) InvokableRun(ctx context.Context, args string, opts ...Option) (string, error) { // wasInterrupted, hasState, state := tool.GetInterruptState[MyState](ctx) // if wasInterrupted && hasState { // // Continue from saved state // return continueFrom(state), nil // } // // First run // return "", tool.StatefulInterrupt(ctx, "need input", MyState{Step: 1}) // } func GetInterruptState[T any](ctx context.Context) (wasInterrupted bool, hasState bool, state T) { return core.GetInterruptState[T](ctx) } // GetResumeContext checks if this tool is the explicit target of a resume operation. // // Returns: // - isResumeTarget: true if this tool was explicitly targeted for resume // - hasData: true if resume data was provided // - data: the resume data (zero value if hasData is false) // // Use this to differentiate between: // - Being resumed as the target (should proceed with work) // - Being re-executed because a sibling was resumed (should re-interrupt) // // Example: // // func (t *MyTool) InvokableRun(ctx context.Context, args string, opts ...Option) (string, error) { // wasInterrupted, _, _ := tool.GetInterruptState[any](ctx) // if !wasInterrupted { // return "", tool.Interrupt(ctx, "need confirmation") // } // // isTarget, hasData, data := tool.GetResumeContext[string](ctx) // if !isTarget { // // Not our turn - re-interrupt // return "", tool.Interrupt(ctx, nil) // } // if hasData { // return data, nil // } // return "default result", nil // } func GetResumeContext[T any](ctx context.Context) (isResumeTarget bool, hasData bool, data T) { return core.GetResumeContext[T](ctx) } ================================================ FILE: components/tool/interrupt_test.go ================================================ /* * 
Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package tool import ( "context" "errors" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/internal/core" ) func TestInterrupt(t *testing.T) { ctx := context.Background() t.Run("basic interrupt", func(t *testing.T) { err := Interrupt(ctx, "test info") assert.Error(t, err) var signal *core.InterruptSignal assert.True(t, errors.As(err, &signal)) assert.Equal(t, "test info", signal.Info) assert.True(t, signal.IsRootCause) }) } func TestStatefulInterrupt(t *testing.T) { ctx := context.Background() t.Run("stateful interrupt", func(t *testing.T) { type myState struct { Value int } state := &myState{Value: 42} err := StatefulInterrupt(ctx, "test info", state) assert.Error(t, err) var signal *core.InterruptSignal assert.True(t, errors.As(err, &signal)) assert.Equal(t, "test info", signal.Info) assert.Equal(t, state, signal.State) assert.True(t, signal.IsRootCause) }) } func TestCompositeInterrupt(t *testing.T) { ctx := context.Background() t.Run("no sub errors falls back to StatefulInterrupt", func(t *testing.T) { err := CompositeInterrupt(ctx, "composite info", "my state") assert.Error(t, err) var signal *core.InterruptSignal assert.True(t, errors.As(err, &signal)) assert.Equal(t, "composite info", signal.Info) assert.Equal(t, "my state", signal.State) assert.True(t, signal.IsRootCause) assert.Empty(t, signal.Subs) }) t.Run("with InterruptSignal sub error", func(t *testing.T) { 
subSignal, _ := core.Interrupt(ctx, "sub info", "sub state", nil) err := CompositeInterrupt(ctx, "composite info", "my state", subSignal) assert.Error(t, err) var signal *core.InterruptSignal assert.True(t, errors.As(err, &signal)) assert.Equal(t, "composite info", signal.Info) assert.Equal(t, "my state", signal.State) assert.Len(t, signal.Subs, 1) assert.Equal(t, "sub info", signal.Subs[0].Info) }) t.Run("with non-interrupt error returns error", func(t *testing.T) { nonInterruptErr := errors.New("regular error") err := CompositeInterrupt(ctx, "composite info", "my state", nonInterruptErr) assert.Error(t, err) assert.Contains(t, err.Error(), "composite interrupt but one of the sub error is not interrupt error") var signal *core.InterruptSignal assert.False(t, errors.As(err, &signal)) }) t.Run("with multiple sub errors", func(t *testing.T) { subSignal1, _ := core.Interrupt(ctx, "sub1 info", nil, nil) subSignal2, _ := core.Interrupt(ctx, "sub2 info", nil, nil) err := CompositeInterrupt(ctx, "composite info", nil, subSignal1, subSignal2) assert.Error(t, err) var signal *core.InterruptSignal assert.True(t, errors.As(err, &signal)) assert.Len(t, signal.Subs, 2) }) } func TestGetInterruptState(t *testing.T) { t.Run("not interrupted returns false", func(t *testing.T) { ctx := context.Background() wasInterrupted, hasState, state := GetInterruptState[string](ctx) assert.False(t, wasInterrupted) assert.False(t, hasState) assert.Empty(t, state) }) } func TestGetResumeContext(t *testing.T) { t.Run("not resume target returns false", func(t *testing.T) { ctx := context.Background() isResumeTarget, hasData, data := GetResumeContext[string](ctx) assert.False(t, isResumeTarget) assert.False(t, hasData) assert.Empty(t, data) }) } ================================================ FILE: components/tool/option.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use 
this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package tool // Option defines call option for InvokableTool or StreamableTool component, which is part of component interface signature. // Each tool implementation could define its own options struct and option funcs within its own package, // then wrap the impl specific option funcs into this type, before passing to InvokableRun or StreamableRun. type Option struct { implSpecificOptFn any } // WrapImplSpecificOptFn wraps the impl specific option functions into Option type. // T: the type of the impl specific options struct. // Tool implementations are required to use this function to convert its own option functions into the unified Option type. // For example, if the tool defines its own options struct: // // type customOptions struct { // conf string // } // // Then the tool needs to provide an option function as such: // // func WithConf(conf string) Option { // return WrapImplSpecificOptFn(func(o *customOptions) { // o.conf = conf // } // } // // . func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ implSpecificOptFn: optFn, } } // GetImplSpecificOptions provides tool author the ability to extract their own custom options from the unified Option type. // T: the type of the impl specific options struct. // This function should be used within the tool implementation's InvokableRun or StreamableRun functions. // It is recommended to provide a base T as the first argument, within which the tool author can provide default values for the impl specific options. // eg. 
// // type customOptions struct { // conf string // } // defaultOptions := &customOptions{} // // customOptions := tool.GetImplSpecificOptions(defaultOptions, opts...) func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { optFn, ok := opt.implSpecificOptFn.(func(*T)) if ok { optFn(base) } } } return base } ================================================ FILE: components/tool/option_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package tool import ( "testing" "github.com/smartystreets/goconvey/convey" ) func TestImplSpecificOpts(t *testing.T) { convey.Convey("TestImplSpecificOpts", t, func() { type implSpecificOptions struct { conf string index int } withConf := func(conf string) func(o *implSpecificOptions) { return func(o *implSpecificOptions) { o.conf = conf } } withIndex := func(index int) func(o *implSpecificOptions) { return func(o *implSpecificOptions) { o.index = index } } toolOption1 := WrapImplSpecificOptFn(withConf("test_conf")) toolOption2 := WrapImplSpecificOptFn(withIndex(1)) implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, toolOption1, toolOption2) convey.So(implSpecificOpts, convey.ShouldResemble, &implSpecificOptions{ conf: "test_conf", index: 1, }) }) } ================================================ FILE: components/tool/utils/common.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "github.com/bytedance/sonic" ) func marshalString(resp any) (string, error) { if rs, ok := resp.(string); ok { return rs, nil } return sonic.MarshalString(resp) } ================================================ FILE: components/tool/utils/common_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "fmt" "testing" "github.com/stretchr/testify/assert" ) func TestMarshalString(t *testing.T) { tests := []struct { name string input interface{} expected string hasError bool }{ { name: "string input should return as-is", input: "hello world", expected: "hello world", hasError: false, }, { name: "empty string should return empty", input: "", expected: "", hasError: false, }, { name: "string with special characters", input: "hello\nworld\t\"test\"", expected: "hello\nworld\t\"test\"", hasError: false, }, { name: "string with unicode", input: "你好世界", expected: "你好世界", hasError: false, }, { name: "integer should be marshaled to JSON", input: 42, expected: "42", hasError: false, }, { name: "float should be marshaled to JSON", input: 3.14, expected: "3.14", hasError: false, }, { name: "boolean true should be marshaled to JSON", input: true, expected: "true", hasError: false, }, { name: "boolean false should be marshaled to JSON", input: false, expected: "false", hasError: false, }, { name: "nil should be marshaled to JSON null", input: nil, expected: "null", hasError: false, }, { name: "slice should be marshaled to JSON array", input: []int{1, 2, 3}, expected: "[1,2,3]", hasError: false, }, { name: "empty slice should be marshaled to JSON empty array", input: []int{}, expected: "[]", hasError: false, }, { name: "empty map should be marshaled to JSON empty object", input: map[string]int{}, expected: "{}", hasError: false, }, { name: "struct should be marshaled to JSON", input: struct{ Name string }{Name: "test"}, expected: `{"Name":"test"}`, 
hasError: false, }, { name: "pointer to string should be handled as non-string", input: func() *string { s := "test"; return &s }(), expected: `"test"`, hasError: false, }, { name: "interface{} containing string should return as-is", input: interface{}("test string"), expected: "test string", hasError: false, }, { name: "interface{} containing int should be marshaled", input: interface{}(123), expected: "123", hasError: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := marshalString(tt.input) if tt.hasError { assert.Error(t, err) } else { assert.NoError(t, err) assert.Equal(t, tt.expected, result) } }) } } func TestMarshalStringEdgeCases(t *testing.T) { t.Run("complex nested structure", func(t *testing.T) { complex := map[string]interface{}{ "string": "value", "number": 42, "nested": map[string]interface{}{ "array": []string{"a", "b", "c"}, "bool": true, }, } result, err := marshalString(complex) assert.NoError(t, err) assert.Contains(t, result, `"string":"value"`) assert.Contains(t, result, `"number":42`) assert.Contains(t, result, `"nested"`) }) t.Run("string type assertion priority", func(t *testing.T) { // Test that string type assertion has priority over JSON marshaling var input interface{} = "direct string" result, err := marshalString(input) assert.NoError(t, err) assert.Equal(t, "direct string", result) // Verify it's not JSON encoded assert.NotEqual(t, `"direct string"`, result) }) } func TestMarshalStringConsistency(t *testing.T) { t.Run("string vs JSON marshaling difference", func(t *testing.T) { input := `{"key": "value"}` // Direct string should return as-is result, err := marshalString(input) assert.NoError(t, err) assert.Equal(t, input, result) // Should not be double-encoded assert.NotEqual(t, fmt.Sprintf(`"%s"`, input), result) }) } ================================================ FILE: components/tool/utils/create_options.go ================================================ /* * Copyright 2024 CloudWeGo 
Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "context" "reflect" "github.com/eino-contrib/jsonschema" ) // UnmarshalArguments is the function type for unmarshalling the arguments. type UnmarshalArguments func(ctx context.Context, arguments string) (any, error) // MarshalOutput is the function type for marshalling the output. type MarshalOutput func(ctx context.Context, output any) (string, error) type toolOptions struct { um UnmarshalArguments m MarshalOutput scModifier SchemaModifierFn } // Option is the option func for the tool. type Option func(o *toolOptions) // WithUnmarshalArguments wraps the unmarshal arguments option. // when you want to unmarshal the arguments by yourself, you can use this option. func WithUnmarshalArguments(um UnmarshalArguments) Option { return func(o *toolOptions) { o.um = um } } // WithMarshalOutput wraps the marshal output option. // when you want to marshal the output by yourself, you can use this option. func WithMarshalOutput(m MarshalOutput) Option { return func(o *toolOptions) { o.m = m } } // SchemaModifierFn is the schema modifier function for inferring tool parameter from tagged go struct. // Within this function, end-user can parse custom go struct tags into corresponding json schema field. // Parameters: // 1. jsonTagName: the name defined in the json tag. Specifically, the last 'jsonTagName' visited is fixed to be '_root', which represents the entire go struct. 
Also, for array field, both the field itself and the element within the array will trigger this function. // 2. t: the type of current schema, usually the field type of the go struct. // 3. tag: the struct tag of current schema, usually the field tag of the go struct. Note that the element within an array field will use the same go struct tag as the array field itself. // 4. schema: the current json schema object to be modified. type SchemaModifierFn func(jsonTagName string, t reflect.Type, tag reflect.StructTag, schema *jsonschema.Schema) // WithSchemaModifier sets a user-defined schema modifier for inferring tool parameter from tagged go struct. func WithSchemaModifier(modifier SchemaModifierFn) Option { return func(o *toolOptions) { o.scModifier = modifier } } func getToolOptions(opt ...Option) *toolOptions { opts := &toolOptions{ um: nil, m: nil, } for _, o := range opt { o(opts) } return opts } ================================================ FILE: components/tool/utils/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package utils provides constructors for building tool implementations without // writing boilerplate JSON serialization code. // // # Choosing a Constructor // // There are two main strategies: // // 1. Infer from struct tags (recommended): [InferTool], [InferStreamTool], // [InferEnhancedTool], [InferEnhancedStreamTool]. 
// The parameter JSON schema is derived automatically from the input struct's // field names and tags. Requires a typed input struct. // // 2. Manual ToolInfo: [NewTool], [NewStreamTool], [NewEnhancedTool], // [NewEnhancedStreamTool]. // You supply a [schema.ToolInfo] directly. Useful when the schema cannot // be expressed as a Go struct, or must be dynamically constructed. // // # Struct Tag Convention // // InferTool and friends use the following tags on the input struct fields: // // type Input struct { // Query string `json:"query" jsonschema:"required" jsonschema_description:"The search query"` // MaxItems int `json:"max_items" jsonschema_description:"Maximum results to return"` // } // // Key rules: // - Use a separate jsonschema_description tag for field descriptions — // embedding descriptions inside the jsonschema tag causes comma-parsing // issues. // - Use jsonschema:"required" to mark mandatory parameters. // - The json tag controls the parameter name visible to the model. // // # Schema Utilities // // [GoStruct2ToolInfo] and [GoStruct2ParamsOneOf] convert a Go struct to schema // types without creating a tool — useful for ChatModel structured output via // ResponseFormat or BindTools. // // See https://www.cloudwego.io/docs/eino/core_modules/components/tools_node_guide/how_to_create_a_tool/ package utils ================================================ FILE: components/tool/utils/error_handler.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "context" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) // ErrorHandler converts a tool error into a string response. type ErrorHandler func(context.Context, error) string // WrapToolWithErrorHandler wraps any BaseTool with custom error handling. // This function detects the tool type (InvokableTool, StreamableTool, or both) // and applies the appropriate error handling wrapper. // When the wrapped tool returns an error, the error handler function 'h' will be called // to convert the error into a string result, and no error will be returned from the wrapper. // // Parameters: // - t: The original BaseTool to be wrapped // - h: A function that converts an error to a string // // Returns: // - A wrapped BaseTool that handles errors internally based on its capabilities func WrapToolWithErrorHandler(t tool.BaseTool, h ErrorHandler) tool.BaseTool { ih := &infoHelper{info: t.Info} var s tool.StreamableTool if st, ok := t.(tool.StreamableTool); ok { s = st } if it, ok := t.(tool.InvokableTool); ok { if s == nil { return WrapInvokableToolWithErrorHandler(it, h) } else { return &combinedErrorWrapper{ infoHelper: ih, errorHelper: &errorHelper{ i: it.InvokableRun, h: h, }, streamErrorHelper: &streamErrorHelper{ s: s.StreamableRun, h: h, }, } } } if s != nil { return WrapStreamableToolWithErrorHandler(s, h) } return t } // WrapInvokableToolWithErrorHandler wraps an InvokableTool with custom error handling. // When the wrapped tool returns an error, the error handler function 'h' will be called // to convert the error into a string result, and no error will be returned from the wrapper. 
// // Parameters: // - tool: The original InvokableTool to be wrapped // - h: A function that converts an error to a string // // Returns: // - A wrapped InvokableTool that handles errors internally func WrapInvokableToolWithErrorHandler(t tool.InvokableTool, h ErrorHandler) tool.InvokableTool { return &errorWrapper{ infoHelper: &infoHelper{info: t.Info}, errorHelper: &errorHelper{ i: t.InvokableRun, h: h, }, } } // WrapStreamableToolWithErrorHandler wraps a StreamableTool with custom error handling. // When the wrapped tool returns an error, the error handler function 'h' will be called // to convert the error into a string result, which will be returned as a single-item stream, // and no error will be returned from the wrapper. // // Parameters: // - tool: The original StreamableTool to be wrapped // - h: A function that converts an error to a string // // Returns: // - A wrapped StreamableTool that handles errors internally func WrapStreamableToolWithErrorHandler(t tool.StreamableTool, h ErrorHandler) tool.StreamableTool { return &streamErrorWrapper{ infoHelper: &infoHelper{info: t.Info}, streamErrorHelper: &streamErrorHelper{ s: t.StreamableRun, h: h, }, } } type errorWrapper struct { *infoHelper *errorHelper } type streamErrorWrapper struct { *infoHelper *streamErrorHelper } type combinedErrorWrapper struct { *infoHelper *errorHelper *streamErrorHelper } type infoHelper struct { info func(ctx context.Context) (*schema.ToolInfo, error) } func (i *infoHelper) Info(ctx context.Context) (*schema.ToolInfo, error) { return i.info(ctx) } type errorHelper struct { i func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) h ErrorHandler } func (s *errorHelper) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { result, err := s.i(ctx, argumentsInJSON, opts...) 
if _, ok := compose.IsInterruptRerunError(err); ok { return result, err } if err != nil { return s.h(ctx, err), nil } return result, nil } type streamErrorHelper struct { s func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) h ErrorHandler } func (s *streamErrorHelper) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { result, err := s.s(ctx, argumentsInJSON, opts...) if _, ok := compose.IsInterruptRerunError(err); ok { return result, err } if err != nil { return schema.StreamReaderFromArray([]string{s.h(ctx, err)}), nil } return result, nil } ================================================ FILE: components/tool/utils/error_handler_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package utils import ( "context" "errors" "io" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) type testErrorTool struct{} func (t *testErrorTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return nil, nil } func (t *testErrorTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { return "", errors.New("test error") } func (t *testErrorTool) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { return nil, errors.New("test stream error") } func TestErrorWrapper(t *testing.T) { ctx := context.Background() nt := WrapToolWithErrorHandler(&testErrorTool{}, func(_ context.Context, err error) string { return err.Error() }) result, err := nt.(tool.InvokableTool).InvokableRun(ctx, "") assert.NoError(t, err) assert.Equal(t, "test error", result) streamResult, err := nt.(tool.StreamableTool).StreamableRun(ctx, "") assert.NoError(t, err) chunk, err := streamResult.Recv() assert.NoError(t, err) assert.Equal(t, "test stream error", chunk) _, err = streamResult.Recv() assert.True(t, errors.Is(err, io.EOF)) wrappedTool := WrapInvokableToolWithErrorHandler(&testErrorTool{}, func(_ context.Context, err error) string { return err.Error() }) result, err = wrappedTool.InvokableRun(ctx, "") assert.NoError(t, err) assert.Equal(t, "test error", result) wrappedStreamTool := WrapStreamableToolWithErrorHandler(&testErrorTool{}, func(_ context.Context, err error) string { return err.Error() }) streamResult, err = wrappedStreamTool.StreamableRun(ctx, "") assert.NoError(t, err) chunk, err = streamResult.Recv() assert.NoError(t, err) assert.Equal(t, "test stream error", chunk) _, err = streamResult.Recv() assert.True(t, errors.Is(err, io.EOF)) } ================================================ FILE: components/tool/utils/invokable_func.go ================================================ /* * 
Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "context" "fmt" "strings" "github.com/bytedance/sonic" "github.com/eino-contrib/jsonschema" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) // InvokeFunc is the function type for the tool. type InvokeFunc[T, D any] func(ctx context.Context, input T) (output D, err error) // OptionableInvokeFunc is the function type for the tool with tool option. type OptionableInvokeFunc[T, D any] func(ctx context.Context, input T, opts ...tool.Option) (output D, err error) // InferTool creates an [tool.InvokableTool] by inferring the parameter JSON // schema from the fields and tags of the input type T. // // The tool automatically JSON-decodes the model's argument string into T before // calling fn, and JSON-encodes the D return value into the result string. // // Use [WithSchemaModifier] in opts to customise how struct tags are mapped to // JSON schema fields. func InferTool[T, D any](toolName, toolDesc string, i InvokeFunc[T, D], opts ...Option) (tool.InvokableTool, error) { ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) if err != nil { return nil, err } return NewTool(ti, i, opts...), nil } // InferOptionableTool is like [InferTool] but the function also receives // [tool.Option] values passed by ToolsNode at call time. 
func InferOptionableTool[T, D any](toolName, toolDesc string, i OptionableInvokeFunc[T, D], opts ...Option) (tool.InvokableTool, error) { ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) if err != nil { return nil, err } return newOptionableTool(ti, i, opts...), nil } // EnhancedInvokeFunc is the function type for the enhanced tool. type EnhancedInvokeFunc[T any] func(ctx context.Context, input T) (output *schema.ToolResult, err error) // OptionableEnhancedInvokeFunc is the function type for the enhanced tool with tool option. type OptionableEnhancedInvokeFunc[T any] func(ctx context.Context, input T, opts ...tool.Option) (output *schema.ToolResult, err error) // InferEnhancedTool creates an [tool.EnhancedInvokableTool] by inferring the // parameter JSON schema from type T. The function returns a [schema.ToolResult] // for multimodal output (text, images, audio, video, files). func InferEnhancedTool[T any](toolName, toolDesc string, i EnhancedInvokeFunc[T], opts ...Option) (tool.EnhancedInvokableTool, error) { ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) if err != nil { return nil, err } return NewEnhancedTool(ti, i, opts...), nil } // InferOptionableEnhancedTool creates an EnhancedInvokableTool from a given function by inferring the ToolInfo from the function's request parameters, with tool option. func InferOptionableEnhancedTool[T any](toolName, toolDesc string, i OptionableEnhancedInvokeFunc[T], opts ...Option) (tool.EnhancedInvokableTool, error) { ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) if err != nil { return nil, err } return newOptionableEnhancedTool(ti, i, opts...), nil } // GoStruct2ParamsOneOf converts a Go struct's fields and tags into a // [schema.ParamsOneOf] (JSON Schema 2020-12). Useful for ChatModel structured // output via ResponseFormat without creating a full tool. func GoStruct2ParamsOneOf[T any](opts ...Option) (*schema.ParamsOneOf, error) { return goStruct2ParamsOneOf[T](opts...) 
} // GoStruct2ToolInfo converts a Go struct into a [schema.ToolInfo]. Useful for // binding a typed schema to a ChatModel via BindTools for structured output, // when you do not need a full executable tool. func GoStruct2ToolInfo[T any](toolName, toolDesc string, opts ...Option) (*schema.ToolInfo, error) { return goStruct2ToolInfo[T](toolName, toolDesc, opts...) } func goStruct2ToolInfo[T any](toolName, toolDesc string, opts ...Option) (*schema.ToolInfo, error) { paramsOneOf, err := goStruct2ParamsOneOf[T](opts...) if err != nil { return nil, err } return &schema.ToolInfo{ Name: toolName, Desc: toolDesc, ParamsOneOf: paramsOneOf, }, nil } func goStruct2ParamsOneOf[T any](opts ...Option) (*schema.ParamsOneOf, error) { options := getToolOptions(opts...) r := &jsonschema.Reflector{ Anonymous: true, DoNotReference: true, SchemaModifier: jsonschema.SchemaModifierFn(options.scModifier), } js := r.Reflect(generic.NewInstance[T]()) js.Version = "" paramsOneOf := schema.NewParamsOneOfByJSONSchema(js) return paramsOneOf, nil } // NewTool creates an [tool.InvokableTool] from an explicit [schema.ToolInfo] // and a typed function. Use this when the schema cannot be inferred from struct // tags (e.g. dynamic or complex parameter schemas). // // Note: you are responsible for keeping desc.ParamsOneOf consistent with the // actual fields of T — there is no compile-time check. func NewTool[T, D any](desc *schema.ToolInfo, i InvokeFunc[T, D], opts ...Option) tool.InvokableTool { return newOptionableTool(desc, func(ctx context.Context, input T, _ ...tool.Option) (D, error) { return i(ctx, input) }, opts...) } func newOptionableTool[T, D any](desc *schema.ToolInfo, i OptionableInvokeFunc[T, D], opts ...Option) tool.InvokableTool { to := getToolOptions(opts...) 
return &invokableTool[T, D]{ info: desc, um: to.um, m: to.m, Fn: i, } } type invokableTool[T, D any] struct { info *schema.ToolInfo um UnmarshalArguments m MarshalOutput Fn OptionableInvokeFunc[T, D] } func (i *invokableTool[T, D]) Info(ctx context.Context) (*schema.ToolInfo, error) { return i.info, nil } // InvokableRun invokes the tool with the given arguments. func (i *invokableTool[T, D]) InvokableRun(ctx context.Context, arguments string, opts ...tool.Option) (output string, err error) { var inst T if i.um != nil { var val any val, err = i.um(ctx, arguments) if err != nil { return "", fmt.Errorf("[LocalFunc] failed to unmarshal arguments, toolName=%s, err=%w", i.getToolName(), err) } gt, ok := val.(T) if !ok { return "", fmt.Errorf("[LocalFunc] invalid type, toolName=%s, expected=%T, given=%T", i.getToolName(), inst, val) } inst = gt } else { inst = generic.NewInstance[T]() err = sonic.UnmarshalString(arguments, &inst) if err != nil { return "", fmt.Errorf("[LocalFunc] failed to unmarshal arguments in json, toolName=%s, err=%w", i.getToolName(), err) } } resp, err := i.Fn(ctx, inst, opts...) if err != nil { return "", fmt.Errorf("[LocalFunc] failed to invoke tool, toolName=%s, err=%w", i.getToolName(), err) } if i.m != nil { output, err = i.m(ctx, resp) if err != nil { return "", fmt.Errorf("[LocalFunc] failed to marshal output, toolName=%s, err=%w", i.getToolName(), err) } } else { output, err = marshalString(resp) if err != nil { return "", fmt.Errorf("[LocalFunc] failed to marshal output in json, toolName=%s, err=%w", i.getToolName(), err) } } return output, nil } func (i *invokableTool[T, D]) GetType() string { return snakeToCamel(i.getToolName()) } func (i *invokableTool[T, D]) getToolName() string { if i.info == nil { return "" } return i.info.Name } // snakeToCamel converts a snake_case string to CamelCase. 
func snakeToCamel(s string) string { if s == "" { return "" } parts := strings.Split(s, "_") for i := 0; i < len(parts); i++ { if len(parts[i]) > 0 { parts[i] = strings.ToUpper(string(parts[i][0])) + strings.ToLower(parts[i][1:]) } } return strings.Join(parts, "") } // NewEnhancedTool creates an [tool.EnhancedInvokableTool] from an explicit // [schema.ToolInfo] and a function that returns [schema.ToolResult]. func NewEnhancedTool[T any](desc *schema.ToolInfo, i EnhancedInvokeFunc[T], opts ...Option) tool.EnhancedInvokableTool { return newOptionableEnhancedTool(desc, func(ctx context.Context, input T, _ ...tool.Option) (*schema.ToolResult, error) { return i(ctx, input) }, opts...) } func newOptionableEnhancedTool[T any](desc *schema.ToolInfo, i OptionableEnhancedInvokeFunc[T], opts ...Option) tool.EnhancedInvokableTool { to := getToolOptions(opts...) return &enhancedInvokableTool[T]{ info: desc, um: to.um, Fn: i, } } type enhancedInvokableTool[T any] struct { info *schema.ToolInfo um UnmarshalArguments Fn OptionableEnhancedInvokeFunc[T] } func (e *enhancedInvokableTool[T]) Info(ctx context.Context) (*schema.ToolInfo, error) { return e.info, nil } func (e *enhancedInvokableTool[T]) InvokableRun(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { var inst T var err error if e.um != nil { var val any val, err = e.um(ctx, toolArgument.Text) if err != nil { return nil, fmt.Errorf("[EnhancedLocalFunc] failed to unmarshal arguments, toolName=%s, err=%w", e.getToolName(), err) } gt, ok := val.(T) if !ok { return nil, fmt.Errorf("[EnhancedLocalFunc] invalid type, toolName=%s, expected=%T, given=%T", e.getToolName(), inst, val) } inst = gt } else { inst = generic.NewInstance[T]() err = sonic.UnmarshalString(toolArgument.Text, &inst) if err != nil { return nil, fmt.Errorf("[EnhancedLocalFunc] failed to unmarshal arguments in json, toolName=%s, err=%w", e.getToolName(), err) } } resp, err := e.Fn(ctx, inst, opts...) 
if err != nil { return nil, fmt.Errorf("[EnhancedLocalFunc] failed to invoke tool, toolName=%s, err=%w", e.getToolName(), err) } return resp, nil } func (e *enhancedInvokableTool[T]) GetType() string { return snakeToCamel(e.getToolName()) } func (e *enhancedInvokableTool[T]) getToolName() string { if e.info == nil { return "" } return e.info.Name } ================================================ FILE: components/tool/utils/invokable_func_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

package utils

import (
	"context"
	"encoding/json"
	"fmt"
	"testing"

	"github.com/eino-contrib/jsonschema"
	"github.com/stretchr/testify/assert"
	orderedmap "github.com/wk8/go-ordered-map/v2"

	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/schema"
)

// Job, Income and User form a nested input struct used to check schema
// inference. As the expected toolInfo fixture below shows, fields without
// omitempty end up in the schema's Required list.
type Job struct {
	Company       string  `json:"company" jsonschema:"description=the company where the user works"`
	Position      string  `json:"position,omitempty" jsonschema:"description=the position of the user's job"`
	ServiceLength float32 `json:"service_length,omitempty" jsonschema:"description=the year of user's service"` // length of service, in years
}

type Income struct {
	Source    string `json:"source" jsonschema:"description=the source of income"`
	Amount    int    `json:"amount" jsonschema:"description=the amount of income"`
	HasPayTax bool   `json:"has_pay_tax" jsonschema:"description=whether the user has paid tax"`
	Job       *Job   `json:"job,omitempty" jsonschema:"description=the job of the user when earning this income"`
}

type User struct {
	Name    string    `json:"name" jsonschema:"required,description=the name of the user"`
	Age     int       `json:"age" jsonschema:"required,description=the age of the user"`
	Job     *Job      `json:"job,omitempty" jsonschema:"description=the job of the user"`
	Incomes []*Income `json:"incomes" jsonschema:"description=the incomes of the user"`
}

// UserResult is the output type of the test tools; it is JSON-encoded into the tool result.
type UserResult struct {
	Code int    `json:"code"`
	Msg  string `json:"msg"`
}

// toolInfo is the hand-written schema that InferTool is expected to produce for
// User. Property order matters: TestInferTool compares the JSON encodings as strings.
var toolInfo = &schema.ToolInfo{
	Name: "update_user_info",
	Desc: "full update user info",
	ParamsOneOf: schema.NewParamsOneOfByJSONSchema(
		&jsonschema.Schema{
			Type:                 "object",
			Required:             []string{"name", "age", "incomes"},
			AdditionalProperties: jsonschema.FalseSchema,
			Properties: orderedmap.New[string, *jsonschema.Schema](
				orderedmap.WithInitialData(
					orderedmap.Pair[string, *jsonschema.Schema]{
						Key: "name",
						Value: &jsonschema.Schema{
							Type:        "string",
							Description: "the name of the user",
						},
					},
					orderedmap.Pair[string, *jsonschema.Schema]{
						Key: "age",
						Value: &jsonschema.Schema{
							Type:        "integer",
							Description: "the age of the user",
						},
					},
					orderedmap.Pair[string, *jsonschema.Schema]{
						Key: "job",
						Value: &jsonschema.Schema{
							Type:                 "object",
							Required:             []string{"company"},
							AdditionalProperties: jsonschema.FalseSchema,
							Description:          "the job of the user",
							Properties: orderedmap.New[string, *jsonschema.Schema](
								orderedmap.WithInitialData(
									orderedmap.Pair[string, *jsonschema.Schema]{
										Key: "company",
										Value: &jsonschema.Schema{
											Type:        "string",
											Description: "the company where the user works",
										},
									},
									orderedmap.Pair[string, *jsonschema.Schema]{
										Key: "position",
										Value: &jsonschema.Schema{
											Type:        "string",
											Description: "the position of the user's job",
										},
									},
									orderedmap.Pair[string, *jsonschema.Schema]{
										Key: "service_length",
										Value: &jsonschema.Schema{
											Type:        "number",
											Description: "the year of user's service",
										},
									},
								),
							),
						},
					},
					orderedmap.Pair[string, *jsonschema.Schema]{
						Key: "incomes",
						Value: &jsonschema.Schema{
							Type:        "array",
							Description: "the incomes of the user",
							// Items describes each element of the Incomes slice.
							Items: &jsonschema.Schema{
								Type:                 "object",
								AdditionalProperties: jsonschema.FalseSchema,
								Required:             []string{"source", "amount", "has_pay_tax"},
								Properties: orderedmap.New[string, *jsonschema.Schema](
									orderedmap.WithInitialData(
										orderedmap.Pair[string, *jsonschema.Schema]{
											Key: "source",
											Value: &jsonschema.Schema{
												Type:        "string",
												Description: "the source of income",
											},
										},
										orderedmap.Pair[string, *jsonschema.Schema]{
											Key: "amount",
											Value: &jsonschema.Schema{
												Type:        "integer",
												Description: "the amount of income",
											},
										},
										orderedmap.Pair[string, *jsonschema.Schema]{
											Key: "has_pay_tax",
											Value: &jsonschema.Schema{
												Type:        "boolean",
												Description: "whether the user has paid tax",
											},
										},
										orderedmap.Pair[string, *jsonschema.Schema]{
											Key: "job",
											Value: &jsonschema.Schema{
												Type:                 "object",
												AdditionalProperties: jsonschema.FalseSchema,
												Required:             []string{"company"},
												Description:          "the job of the user when earning this income",
												Properties: orderedmap.New[string, *jsonschema.Schema](
													orderedmap.WithInitialData(
														orderedmap.Pair[string, *jsonschema.Schema]{
															Key: "company",
															Value: &jsonschema.Schema{
																Type:        "string",
																Description: "the company where the user works",
															},
														},
														orderedmap.Pair[string, *jsonschema.Schema]{
															Key: "position",
															Value: &jsonschema.Schema{
																Type:        "string",
																Description: "the position of the user's job",
															},
														},
														orderedmap.Pair[string, *jsonschema.Schema]{
															Key: "service_length",
															Value: &jsonschema.Schema{
																Type:        "number",
																Description: "the year of user's service",
															},
														},
													),
												),
											},
										},
									),
								),
							},
						},
					},
				),
			),
		}),
}

// updateUserInfo is the tool function under test; it echoes the user's name in Msg.
func updateUserInfo(ctx context.Context, input *User) (output *UserResult, err error) {
	return &UserResult{
		Code: 200,
		Msg:  fmt.Sprintf("update %v success", input.Name),
	}, nil
}

// UserInfoOption is an implementation-specific option read by updateUserInfoWithOption.
type UserInfoOption struct {
	Field1 string
}

// WithUserInfoOption wraps a setter for UserInfoOption.Field1 as a tool.Option.
func WithUserInfoOption(s string) tool.Option {
	return tool.WrapImplSpecificOptFn(func(t *UserInfoOption) {
		t.Field1 = s
	})
}

// updateUserInfoWithOption returns Field1 from the call-time options in Msg,
// or "test_origin" when no option overrides it.
func updateUserInfoWithOption(_ context.Context, input *User, opts ...tool.Option) (output *UserResult, err error) {
	baseOption := &UserInfoOption{
		Field1: "test_origin",
	}
	option := tool.GetImplSpecificOptions(baseOption, opts...)
	return &UserResult{
		Code: 200,
		Msg:  option.Field1,
	}, nil
}

// TestInferTool checks that the schema inferred from User matches toolInfo
// exactly and that the resulting tool can be invoked end to end.
func TestInferTool(t *testing.T) {
	t.Run("invoke_infer_tool", func(t *testing.T) {
		ctx := context.Background()

		tl, err := InferTool("update_user_info", "full update user info", updateUserInfo)
		assert.NoError(t, err)

		info, err := tl.Info(context.Background())
		assert.NoError(t, err)

		actual, err := info.ToJSONSchema()
		assert.NoError(t, err)
		actualStr, err := json.Marshal(actual)
		assert.NoError(t, err)

		expect, err := toolInfo.ToJSONSchema()
		assert.NoError(t, err)
		expectStr, err := json.Marshal(expect)
		assert.NoError(t, err)

		assert.Equal(t, string(expectStr), string(actualStr))

		content, err := tl.InvokableRun(ctx, `{"name": "bruce lee"}`)
		assert.NoError(t, err)
		assert.JSONEq(t, `{"code":200,"msg":"update bruce lee success"}`, content)
	})
}

// TestInferOptionableTool checks that call-time options reach the tool function.
func TestInferOptionableTool(t *testing.T) {
	ctx := context.Background()

	t.Run("invoke_infer_optionable_tool", func(t *testing.T) {
		tl, err := InferOptionableTool("invoke_infer_optionable_tool", "full update user info", updateUserInfoWithOption)
		assert.NoError(t, err)

		content, err := tl.InvokableRun(ctx, `{"name": "bruce lee"}`, WithUserInfoOption("hello world"))
		assert.NoError(t, err)
		assert.JSONEq(t, `{"code":200,"msg":"hello world"}`, content)
	})
}

// TestNewTool covers NewTool with a nil ToolInfo and various value, pointer and
// scalar input/output types.
func TestNewTool(t *testing.T) {
	ctx := context.Background()
	type Input struct {
		Name string `json:"name"`
	}
	type Output struct {
		Name string `json:"name"`
	}

	t.Run("struct_input_struct_output", func(t *testing.T) {
		tl := NewTool[Input, Output](nil, func(ctx context.Context, input Input) (output Output, err error) {
			return Output{
				Name: input.Name,
			}, nil
		})

		_, err := tl.InvokableRun(ctx, `{"name":"test"}`)
		assert.Nil(t, err)
	})

	t.Run("pointer_input_pointer_output", func(t *testing.T) {
		tl := NewTool[*Input, *Output](nil, func(ctx context.Context, input *Input) (output *Output, err error) {
			return &Output{
				Name: input.Name,
			}, nil
		})

		content, err := tl.InvokableRun(ctx, `{"name":"test"}`)
		assert.NoError(t, err)
		assert.Equal(t, `{"name":"test"}`, content)
	})

	t.Run("string_input_int64_output", func(t *testing.T) {
		tl := NewTool(nil, func(ctx context.Context, input string) (output int64, err error) {
			return 10, nil
		})

		content, err := tl.InvokableRun(ctx, `100`) // `100` is a JSON number, not a JSON string, so decoding it into a string input must fail.
		assert.Error(t, err)
		assert.Equal(t, "", content)
	})

	t.Run("string_pointer_input_int64_pointer_output", func(t *testing.T) {
		tl := NewTool[*string, *int64](nil, func(ctx context.Context, input *string) (output *int64, err error) {
			n := int64(10)
			return &n, nil
		})

		content, err := tl.InvokableRun(ctx, `"100"`)
		assert.NoError(t, err)
		assert.Equal(t, `10`, content)
	})
}

// TestSnakeToCamel pins snakeToCamel's handling of digits, empty input, a
// single word, upper-case input and leading/trailing underscores.
func TestSnakeToCamel(t *testing.T) {
	t.Run("normal_case", func(t *testing.T) {
		assert.Equal(t, "GoogleSearch3", snakeToCamel("google_search_3"))
	})

	t.Run("empty_case", func(t *testing.T) {
		assert.Equal(t, "", snakeToCamel(""))
	})

	t.Run("single_word_case", func(t *testing.T) {
		assert.Equal(t, "Google", snakeToCamel("google"))
	})

	t.Run("upper_case", func(t *testing.T) {
		assert.Equal(t, "HttpHost", snakeToCamel("_HTTP_HOST_"))
	})

	t.Run("underscore_case", func(t *testing.T) {
		assert.Equal(t, "", snakeToCamel("_"))
	})
}

// Named types over primitive kinds, used to check that enum tags work on
// aliases as well as on the base types.
type stringAlias string
type integerAlias uint32
type floatAlias float64
type boolAlias bool

type testEnumStruct struct {
	Field1 string       `json:"field1" jsonschema:"enum=a,enum=b"`
	Field2 int          `json:"field2" jsonschema:"enum=1,enum=2"`
	Field3 float32      `json:"field3" jsonschema:"enum=1.1,enum=2.2"`
	Field4 bool         `json:"field4" jsonschema:"default=true"`
	Field5 stringAlias  `json:"field5" jsonschema:"enum=a,enum=c"`
	Field6 integerAlias `json:"field6" jsonschema:"enum=3,enum=4"`
	Field7 floatAlias   `json:"field7" jsonschema:"enum=3.3,enum=4.4"`
	Field8 boolAlias    `json:"field8" jsonschema:"enum=false"`
}

// testEnumStruct2 and testEnumStruct3 carry enum values that do not match the
// field type; TestEnumTag only asserts that inference still returns no error.
type testEnumStruct2 struct {
	Field1 int8 `json:"field1" jsonschema:"enum=1.1"`
}

type testEnumStruct3 struct {
	Field1 float64 `json:"field1" jsonschema:"enum=a"`
}

// TestEnumTag checks how enum/default tags are rendered: strings stay strings,
// numeric enums become json.Number values.
func TestEnumTag(t *testing.T) {
	info, err := goStruct2ParamsOneOf[testEnumStruct]()
	assert.NoError(t, err)
	s, err := info.ToJSONSchema()
	assert.NoError(t, err)
	enum, ok := s.Properties.Get("field1")
	assert.True(t, ok)
	assert.Equal(t, []any{"a", "b"}, enum.Enum)
	enum, ok = s.Properties.Get("field2")
	assert.True(t, ok)
	assert.Equal(t, []any{json.Number("1"), json.Number("2")}, enum.Enum)
	enum, ok = s.Properties.Get("field3")
	assert.True(t, ok)
	assert.Equal(t, []any{json.Number("1.1"), json.Number("2.2")}, enum.Enum)
	enum, ok = s.Properties.Get("field4")
	assert.True(t, ok)
	assert.Equal(t, true, enum.Default)
	enum, ok = s.Properties.Get("field5")
	assert.True(t, ok)
	assert.Equal(t, []any{"a", "c"}, enum.Enum)
	enum, ok = s.Properties.Get("field6")
	assert.True(t, ok)
	assert.Equal(t, []any{json.Number("3"), json.Number("4")}, enum.Enum)
	enum, ok = s.Properties.Get("field7")
	assert.True(t, ok)
	assert.Equal(t, []any{json.Number("3.3"), json.Number("4.4")}, enum.Enum)

	_, err = goStruct2ParamsOneOf[testEnumStruct2]()
	assert.NoError(t, err)
	_, err = goStruct2ParamsOneOf[testEnumStruct3]()
	assert.NoError(t, err)
}

================================================
FILE: components/tool/utils/streamable_func.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package utils import ( "context" "fmt" "github.com/bytedance/sonic" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) // StreamFunc is the function type for the streamable tool. type StreamFunc[T, D any] func(ctx context.Context, input T) (output *schema.StreamReader[D], err error) // OptionableStreamFunc is the function type for the streamable tool with tool option. type OptionableStreamFunc[T, D any] func(ctx context.Context, input T, opts ...tool.Option) (output *schema.StreamReader[D], err error) // InferStreamTool creates a [tool.StreamableTool] by inferring the parameter // JSON schema from type T. The function returns a [schema.StreamReader] of D // values which the framework serialises to a string stream. func InferStreamTool[T, D any](toolName, toolDesc string, s StreamFunc[T, D], opts ...Option) (tool.StreamableTool, error) { ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) if err != nil { return nil, err } return NewStreamTool(ti, s, opts...), nil } // InferOptionableStreamTool is like [InferStreamTool] but the function also // receives [tool.Option] values passed by ToolsNode at call time. func InferOptionableStreamTool[T, D any](toolName, toolDesc string, s OptionableStreamFunc[T, D], opts ...Option) (tool.StreamableTool, error) { ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) if err != nil { return nil, err } return newOptionableStreamTool(ti, s, opts...), nil } // NewStreamTool creates a [tool.StreamableTool] from an explicit [schema.ToolInfo] // and a typed streaming function. func NewStreamTool[T, D any](desc *schema.ToolInfo, s StreamFunc[T, D], opts ...Option) tool.StreamableTool { return newOptionableStreamTool(desc, func(ctx context.Context, input T, _ ...tool.Option) (output *schema.StreamReader[D], err error) { return s(ctx, input) }, opts...) 
} func newOptionableStreamTool[T, D any](desc *schema.ToolInfo, s OptionableStreamFunc[T, D], opts ...Option) tool.StreamableTool { to := getToolOptions(opts...) return &streamableTool[T, D]{ info: desc, um: to.um, m: to.m, Fn: s, } } type streamableTool[T, D any] struct { info *schema.ToolInfo um UnmarshalArguments m MarshalOutput Fn OptionableStreamFunc[T, D] } // Info returns the tool info, implement the BaseTool interface. func (s *streamableTool[T, D]) Info(ctx context.Context) (*schema.ToolInfo, error) { return s.info, nil } // StreamableRun invokes the tool with the given arguments, implement the StreamableTool interface. func (s *streamableTool[T, D]) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) ( outStream *schema.StreamReader[string], err error) { var inst T if s.um != nil { var val any val, err = s.um(ctx, argumentsInJSON) if err != nil { return nil, fmt.Errorf("[LocalStreamFunc] failed to unmarshal arguments, toolName=%s, err=%w", s.getToolName(), err) } gt, ok := val.(T) if !ok { return nil, fmt.Errorf("[LocalStreamFunc] type err, toolName=%s, expected=%T, given=%T", s.getToolName(), inst, val) } inst = gt } else { inst = generic.NewInstance[T]() err = sonic.UnmarshalString(argumentsInJSON, &inst) if err != nil { return nil, fmt.Errorf("[LocalStreamFunc] failed to unmarshal arguments in json, toolName=%s, err=%w", s.getToolName(), err) } } streamD, err := s.Fn(ctx, inst, opts...) 
if err != nil { return nil, err } outStream = schema.StreamReaderWithConvert(streamD, func(d D) (string, error) { var out string var e error if s.m != nil { out, e = s.m(ctx, d) if e != nil { return "", fmt.Errorf("[LocalStreamFunc] failed to marshal output, toolName=%s, err=%w", s.getToolName(), e) } } else { out, e = marshalString(d) if e != nil { return "", fmt.Errorf("[LocalStreamFunc] failed to marshal output in json, toolName=%s, err=%w", s.getToolName(), e) } } return out, nil }) return outStream, nil } func (s *streamableTool[T, D]) GetType() string { return snakeToCamel(s.getToolName()) } func (s *streamableTool[T, D]) getToolName() string { if s.info == nil { return "" } return s.info.Name } // EnhancedStreamFunc is the function type for the enhanced streamable tool. type EnhancedStreamFunc[T any] func(ctx context.Context, input T) (output *schema.StreamReader[*schema.ToolResult], err error) // OptionableEnhancedStreamFunc is the function type for the enhanced streamable tool with tool option. type OptionableEnhancedStreamFunc[T any] func(ctx context.Context, input T, opts ...tool.Option) (output *schema.StreamReader[*schema.ToolResult], err error) // InferEnhancedStreamTool creates an [tool.EnhancedStreamableTool] by inferring // the parameter JSON schema from type T. The function streams [schema.ToolResult] // values for multimodal output. func InferEnhancedStreamTool[T any](toolName, toolDesc string, s EnhancedStreamFunc[T], opts ...Option) (tool.EnhancedStreamableTool, error) { ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) if err != nil { return nil, err } return NewEnhancedStreamTool(ti, s, opts...), nil } // InferOptionableEnhancedStreamTool creates an EnhancedStreamableTool from a given function by inferring the ToolInfo from the function's request parameters, with tool option. 
func InferOptionableEnhancedStreamTool[T any](toolName, toolDesc string, s OptionableEnhancedStreamFunc[T], opts ...Option) (tool.EnhancedStreamableTool, error) {
	ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...)
	if err != nil {
		return nil, err
	}

	return newOptionableEnhancedStreamTool(ti, s, opts...), nil
}

// NewEnhancedStreamTool creates an [tool.EnhancedStreamableTool] from an
// explicit [schema.ToolInfo] and a typed streaming function.
func NewEnhancedStreamTool[T any](desc *schema.ToolInfo, s EnhancedStreamFunc[T], opts ...Option) tool.EnhancedStreamableTool {
	// Adapt s to the optionable signature; per-call tool options are ignored.
	return newOptionableEnhancedStreamTool(desc, func(ctx context.Context, input T, _ ...tool.Option) (output *schema.StreamReader[*schema.ToolResult], err error) {
		return s(ctx, input)
	}, opts...)
}

// newOptionableEnhancedStreamTool wraps an OptionableEnhancedStreamFunc. Of
// the configured options, only the custom argument unmarshaler is used.
func newOptionableEnhancedStreamTool[T any](desc *schema.ToolInfo, s OptionableEnhancedStreamFunc[T], opts ...Option) tool.EnhancedStreamableTool {
	to := getToolOptions(opts...)

	return &enhancedStreamableTool[T]{
		info: desc,
		um:   to.um,
		Fn:   s,
	}
}

// enhancedStreamableTool adapts a typed function that streams
// [schema.ToolResult] values to the tool.EnhancedStreamableTool interface.
type enhancedStreamableTool[T any] struct {
	info *schema.ToolInfo
	// um, when non-nil, replaces the default sonic JSON decoding of arguments.
	um UnmarshalArguments
	Fn OptionableEnhancedStreamFunc[T]
}

// Info returns the tool info, implementing the BaseTool interface.
func (s *enhancedStreamableTool[T]) Info(ctx context.Context) (*schema.ToolInfo, error) {
	return s.info, nil
}

// StreamableRun decodes toolArgument.Text into T and returns the wrapped
// function's result stream unchanged.
//
// A nil toolArgument is reported as an error instead of causing a nil-pointer
// panic. Decoding errors are wrapped with the tool name. Errors from the
// wrapped function are returned as-is.
func (s *enhancedStreamableTool[T]) StreamableRun(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (
	outStream *schema.StreamReader[*schema.ToolResult], err error) {

	if toolArgument == nil {
		return nil, fmt.Errorf("[EnhancedLocalStreamFunc] tool argument is nil, toolName=%s", s.getToolName())
	}

	var inst T
	if s.um != nil {
		var val any
		val, err = s.um(ctx, toolArgument.Text)
		if err != nil {
			return nil, fmt.Errorf("[EnhancedLocalStreamFunc] failed to unmarshal arguments, toolName=%s, err=%w", s.getToolName(), err)
		}
		gt, ok := val.(T)
		if !ok {
			// The custom unmarshaler must produce exactly a T.
			return nil, fmt.Errorf("[EnhancedLocalStreamFunc] type err, toolName=%s, expected=%T, given=%T", s.getToolName(), inst, val)
		}
		inst = gt
	} else {
		inst = generic.NewInstance[T]()

		err = sonic.UnmarshalString(toolArgument.Text, &inst)
		if err != nil {
			return nil, fmt.Errorf("[EnhancedLocalStreamFunc] failed to unmarshal arguments in json, toolName=%s, err=%w", s.getToolName(), err)
		}
	}

	return s.Fn(ctx, inst, opts...)
}

// GetType returns the tool's display type, derived from the tool name via
// snakeToCamel.
func (s *enhancedStreamableTool[T]) GetType() string {
	return snakeToCamel(s.getToolName())
}

// getToolName returns the name from the tool info, or "" when no info is set.
func (s *enhancedStreamableTool[T]) getToolName() string {
	if s.info == nil {
		return ""
	}

	return s.info.Name
}

================================================
FILE: components/tool/utils/streamable_func_test.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package utils import ( "context" "errors" "io" "testing" "github.com/eino-contrib/jsonschema" "github.com/stretchr/testify/assert" orderedmap "github.com/wk8/go-ordered-map/v2" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) func TestNewStreamableTool(t *testing.T) { ctx := context.Background() type Input struct { Name string `json:"name"` } type Output struct { Name string `json:"name"` } t.Run("simple_case", func(t *testing.T) { tl := NewStreamTool[*Input, *Output]( &schema.ToolInfo{ Name: "search_user", Desc: "search user info", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Type: "string", Desc: "user name", }, }), }, func(ctx context.Context, input *Input) (output *schema.StreamReader[*Output], err error) { sr, sw := schema.Pipe[*Output](2) sw.Send(&Output{ Name: input.Name, }, nil) sw.Send(&Output{ Name: "lee", }, nil) sw.Close() return sr, nil }, ) info, err := tl.Info(ctx) assert.NoError(t, err) assert.Equal(t, "search_user", info.Name) js, err := info.ToJSONSchema() assert.NoError(t, err) assert.Equal(t, &jsonschema.Schema{ Type: "object", Properties: orderedmap.New[string, *jsonschema.Schema]( orderedmap.WithInitialData[string, *jsonschema.Schema]( orderedmap.Pair[string, *jsonschema.Schema]{ Key: "name", Value: &jsonschema.Schema{ Type: "string", Description: "user name", }, }, ), ), Required: make([]string, 0), }, js) sr, err := tl.StreamableRun(ctx, `{"name":"xxx"}`) assert.NoError(t, err) defer sr.Close() idx := 0 for { m, err := sr.Recv() if errors.Is(err, io.EOF) { break } assert.NoError(t, err) if idx == 0 { assert.Equal(t, `{"name":"xxx"}`, m) } else { assert.Equal(t, `{"name":"lee"}`, m) } idx++ } assert.Equal(t, 2, idx) }) } type FakeStreamOption struct { Field string } type FakeStreamInferToolInput struct { Field string `json:"field"` } type FakeStreamInferToolOutput struct { Field string `json:"field"` } func FakeWithToolOption(s string) tool.Option { return 
tool.WrapImplSpecificOptFn(func(t *FakeStreamOption) { t.Field = s }) } func fakeStreamFunc(ctx context.Context, input FakeStreamInferToolInput, opts ...tool.Option) (output *schema.StreamReader[*FakeStreamInferToolOutput], err error) { baseOpt := &FakeStreamOption{ Field: "default_field_value", } option := tool.GetImplSpecificOptions(baseOpt, opts...) return schema.StreamReaderFromArray([]*FakeStreamInferToolOutput{ { Field: option.Field, }, }), nil } func TestInferStreamTool(t *testing.T) { st, err := InferOptionableStreamTool("infer_optionable_stream_tool", "test infer stream tool with option", fakeStreamFunc) assert.Nil(t, err) sr, err := st.StreamableRun(context.Background(), `{"field": "value"}`, FakeWithToolOption("hello world")) assert.Nil(t, err) defer sr.Close() idx := 0 for { m, err := sr.Recv() if errors.Is(err, io.EOF) { break } assert.NoError(t, err) if idx == 0 { assert.JSONEq(t, `{"field":"hello world"}`, m) } } } type EnhancedStreamInput struct { Query string `json:"query" jsonschema:"description=the search query"` } func TestNewEnhancedStreamTool(t *testing.T) { ctx := context.Background() t.Run("simple_case", func(t *testing.T) { tl := NewEnhancedStreamTool[*EnhancedStreamInput]( &schema.ToolInfo{ Name: "enhanced_stream_search", Desc: "search with enhanced stream output", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "query": { Type: "string", Desc: "the search query", }, }), }, func(ctx context.Context, input *EnhancedStreamInput) (*schema.StreamReader[*schema.ToolResult], error) { sr, sw := schema.Pipe[*schema.ToolResult](2) sw.Send(&schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: "result for: " + input.Query}, }, }, nil) sw.Send(&schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: "more results"}, }, }, nil) sw.Close() return sr, nil }, ) info, err := tl.Info(ctx) assert.NoError(t, err) assert.Equal(t, "enhanced_stream_search", 
info.Name) sr, err := tl.StreamableRun(ctx, &schema.ToolArgument{Text: `{"query":"test"}`}) assert.NoError(t, err) defer sr.Close() idx := 0 for { m, err := sr.Recv() if errors.Is(err, io.EOF) { break } assert.NoError(t, err) if idx == 0 { assert.Len(t, m.Parts, 1) assert.Equal(t, schema.ToolPartTypeText, m.Parts[0].Type) assert.Equal(t, "result for: test", m.Parts[0].Text) } else { assert.Len(t, m.Parts, 1) assert.Equal(t, "more results", m.Parts[0].Text) } idx++ } assert.Equal(t, 2, idx) }) } type FakeEnhancedStreamOption struct { Prefix string } func FakeWithEnhancedStreamOption(prefix string) tool.Option { return tool.WrapImplSpecificOptFn(func(t *FakeEnhancedStreamOption) { t.Prefix = prefix }) } func fakeEnhancedStreamFunc(ctx context.Context, input EnhancedStreamInput) (*schema.StreamReader[*schema.ToolResult], error) { return schema.StreamReaderFromArray([]*schema.ToolResult{ { Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: "result: " + input.Query}, }, }, }), nil } func fakeOptionableEnhancedStreamFunc(ctx context.Context, input EnhancedStreamInput, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { baseOpt := &FakeEnhancedStreamOption{ Prefix: "default", } option := tool.GetImplSpecificOptions(baseOpt, opts...) 
return schema.StreamReaderFromArray([]*schema.ToolResult{ { Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: option.Prefix + ": " + input.Query}, }, }, }), nil } func TestInferEnhancedStreamTool(t *testing.T) { ctx := context.Background() t.Run("infer_enhanced_stream_tool", func(t *testing.T) { tl, err := InferEnhancedStreamTool("infer_enhanced_stream", "test infer enhanced stream tool", fakeEnhancedStreamFunc) assert.NoError(t, err) info, err := tl.Info(ctx) assert.NoError(t, err) assert.Equal(t, "infer_enhanced_stream", info.Name) sr, err := tl.StreamableRun(ctx, &schema.ToolArgument{Text: `{"query":"hello"}`}) assert.NoError(t, err) defer sr.Close() m, err := sr.Recv() assert.NoError(t, err) assert.Len(t, m.Parts, 1) assert.Equal(t, "result: hello", m.Parts[0].Text) }) } func TestInferOptionableEnhancedStreamTool(t *testing.T) { ctx := context.Background() t.Run("infer_optionable_enhanced_stream_tool", func(t *testing.T) { tl, err := InferOptionableEnhancedStreamTool("infer_optionable_enhanced_stream", "test infer optionable enhanced stream tool", fakeOptionableEnhancedStreamFunc) assert.NoError(t, err) info, err := tl.Info(ctx) assert.NoError(t, err) assert.Equal(t, "infer_optionable_enhanced_stream", info.Name) sr, err := tl.StreamableRun(ctx, &schema.ToolArgument{Text: `{"query":"world"}`}, FakeWithEnhancedStreamOption("custom")) assert.NoError(t, err) defer sr.Close() m, err := sr.Recv() assert.NoError(t, err) assert.Len(t, m.Parts, 1) assert.Equal(t, "custom: world", m.Parts[0].Text) }) t.Run("infer_optionable_enhanced_stream_tool_default_option", func(t *testing.T) { tl, err := InferOptionableEnhancedStreamTool("infer_optionable_enhanced_stream", "test infer optionable enhanced stream tool", fakeOptionableEnhancedStreamFunc) assert.NoError(t, err) sr, err := tl.StreamableRun(ctx, &schema.ToolArgument{Text: `{"query":"test"}`}) assert.NoError(t, err) defer sr.Close() m, err := sr.Recv() assert.NoError(t, err) assert.Len(t, m.Parts, 1) 
assert.Equal(t, "default: test", m.Parts[0].Text) }) } ================================================ FILE: components/types.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package components defines common interfaces that describe component // types and callback capabilities used across Eino. package components // Typer provides a human-readable type name for a component implementation. // // When implemented, the component's full display name in DevOps tooling // (visual debugger, IDE plugin, dashboards) becomes "{GetType()}{ComponentKind}" // — e.g. "OpenAIChatModel". Use CamelCase naming. // // Also used by [utils.InferTool] and similar constructors to set the display // name of tool instances. type Typer interface { GetType() string } // GetType returns the type name for a component that implements Typer. func GetType(component any) (string, bool) { if typer, ok := component.(Typer); ok { return typer.GetType(), true } return "", false } // Checker controls whether the framework's automatic callback instrumentation // is active for a component. // // When IsCallbacksEnabled returns true, the framework skips its default // OnStart/OnEnd wrapping and trusts the component to invoke callbacks itself // at the correct points. 
Implement this when your component needs precise // control over callback timing or content — for example, when streaming // requires callbacks to fire mid-stream rather than only at completion. type Checker interface { IsCallbacksEnabled() bool } // IsCallbacksEnabled reports whether a component implements Checker and enables callbacks. func IsCallbacksEnabled(i any) bool { if checker, ok := i.(Checker); ok { return checker.IsCallbacksEnabled() } return false } // Component names representing the different categories of components. type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" // ComponentOfEmbedding identifies embedding components. ComponentOfEmbedding Component = "Embedding" // ComponentOfIndexer identifies indexer components. ComponentOfIndexer Component = "Indexer" // ComponentOfRetriever identifies retriever components. ComponentOfRetriever Component = "Retriever" // ComponentOfLoader identifies loader components. ComponentOfLoader Component = "Loader" // ComponentOfTransformer identifies document transformer components. ComponentOfTransformer Component = "DocumentTransformer" // ComponentOfTool identifies tool components. ComponentOfTool Component = "Tool" ) ================================================ FILE: compose/branch.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "fmt" "reflect" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) // GraphBranchCondition is the condition type for the branch. type GraphBranchCondition[T any] func(ctx context.Context, in T) (endNode string, err error) // StreamGraphBranchCondition is the condition type for the stream branch. type StreamGraphBranchCondition[T any] func(ctx context.Context, in *schema.StreamReader[T]) (endNode string, err error) // GraphMultiBranchCondition is the condition type for the multi choice branch. type GraphMultiBranchCondition[T any] func(ctx context.Context, in T) (endNode map[string]bool, err error) // StreamGraphMultiBranchCondition is the condition type for the stream multi choice branch. type StreamGraphMultiBranchCondition[T any] func(ctx context.Context, in *schema.StreamReader[T]) (endNodes map[string]bool, err error) // GraphBranch is the branch type for the graph. // It is used to determine the next node based on the condition. type GraphBranch struct { invoke func(ctx context.Context, input any) (output []string, err error) collect func(ctx context.Context, input streamReader) (output []string, err error) inputType reflect.Type *genericHelper endNodes map[string]bool idx int // used to distinguish branches in parallel noDataFlow bool } // GetEndNode returns the all end nodes of the branch. func (gb *GraphBranch) GetEndNode() map[string]bool { return gb.endNodes } func newGraphBranch[T any](r *runnablePacker[T, []string, any], endNodes map[string]bool) *GraphBranch { return &GraphBranch{ invoke: func(ctx context.Context, input any) (output []string, err error) { in, ok := input.(T) if !ok { // When a nil is passed as an 'any' type, its original type information is lost, // becoming an untyped nil. This would cause type assertions to fail. 
// So if the input is nil and the target type T is an interface, we need to explicitly create a nil of type T. if input == nil && generic.TypeOf[T]().Kind() == reflect.Interface { var i T in = i } else { panic(newUnexpectedInputTypeErr(generic.TypeOf[T](), reflect.TypeOf(input))) } } return r.Invoke(ctx, in) }, collect: func(ctx context.Context, input streamReader) (output []string, err error) { in, ok := unpackStreamReader[T](input) if !ok { panic(newUnexpectedInputTypeErr(generic.TypeOf[T](), input.getType())) } return r.Collect(ctx, in) }, inputType: generic.TypeOf[T](), genericHelper: newGenericHelper[T, T](), endNodes: endNodes, } } // NewGraphMultiBranch creates a branch for graphs where a condition selects // multiple end nodes; only keys present in endNodes are allowed. func NewGraphMultiBranch[T any](condition GraphMultiBranchCondition[T], endNodes map[string]bool) *GraphBranch { condRun := func(ctx context.Context, in T, opts ...any) ([]string, error) { ends, err := condition(ctx, in) if err != nil { return nil, err } ret := make([]string, 0, len(ends)) for end := range ends { if !endNodes[end] { return nil, fmt.Errorf("branch invocation returns unintended end node: %s", end) } ret = append(ret, end) } return ret, nil } return newGraphBranch(newRunnablePacker(condRun, nil, nil, nil, false), endNodes) } // NewStreamGraphMultiBranch creates a streaming branch where a condition on // the input stream selects multiple end nodes. 
func NewStreamGraphMultiBranch[T any](condition StreamGraphMultiBranchCondition[T], endNodes map[string]bool) *GraphBranch {
	pick := func(ctx context.Context, in *schema.StreamReader[T], _ ...any) ([]string, error) {
		chosen, err := condition(ctx, in)
		if err != nil {
			return nil, err
		}

		keys := make([]string, 0, len(chosen))
		for key := range chosen {
			if !endNodes[key] {
				return nil, fmt.Errorf("branch invocation returns unintended end node: %s", key)
			}
			keys = append(keys, key)
		}
		return keys, nil
	}

	// Only the collect slot is provided; the condition consumes a stream.
	return newGraphBranch(newRunnablePacker(nil, nil, pick, nil, false), endNodes)
}

// NewGraphBranch creates a branch whose condition selects exactly one next
// node. The returned key must map to true in endNodes.
// e.g.
//
//	condition := func(ctx context.Context, in string) (string, error) {
//		// logic to determine the next node
//		return "next_node_key", nil
//	}
//	endNodes := map[string]bool{"path01": true, "path02": true}
//	branch := compose.NewGraphBranch(condition, endNodes)
//
//	graph.AddBranch("key_of_node_before_branch", branch)
func NewGraphBranch[T any](condition GraphBranchCondition[T], endNodes map[string]bool) *GraphBranch {
	single := func(ctx context.Context, in T) (map[string]bool, error) {
		next, err := condition(ctx, in)
		if err != nil {
			return nil, err
		}
		return map[string]bool{next: true}, nil
	}

	return NewGraphMultiBranch(single, endNodes)
}

// NewStreamGraphBranch creates a new stream graph branch.
// It is used to determine the next node based on the condition of stream input.
// e.g.
//
//	condition := func(ctx context.Context, in *schema.StreamReader[T]) (string, error) {
//		// logic to determine the next node.
//		// to use the feature of stream, you can use the first chunk to determine the next node.
// return "next_node_key", nil // } // endNodes := map[string]bool{"path01": true, "path02": true} // branch := compose.NewStreamGraphBranch(condition, endNodes) // // graph.AddBranch("key_of_node_before_branch", branch) func NewStreamGraphBranch[T any](condition StreamGraphBranchCondition[T], endNodes map[string]bool) *GraphBranch { return NewStreamGraphMultiBranch(func(ctx context.Context, in *schema.StreamReader[T]) (endNode map[string]bool, err error) { ret, err := condition(ctx, in) if err != nil { return nil, err } return map[string]bool{ret: true}, nil }, endNodes) } ================================================ FILE: compose/branch_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "io" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestMultiBranch(t *testing.T) { g := NewGraph[string, map[string]any]() emptyLambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }) err := g.AddLambdaNode("1", emptyLambda, WithOutputKey("1")) assert.NoError(t, err) err = g.AddLambdaNode("2", emptyLambda, WithOutputKey("2")) assert.NoError(t, err) err = g.AddLambdaNode("3", emptyLambda, WithOutputKey("3")) assert.NoError(t, err) err = g.AddBranch(START, NewGraphMultiBranch(func(ctx context.Context, in string) (endNode map[string]bool, err error) { return map[string]bool{"1": true, "2": true}, nil }, map[string]bool{"1": true, "2": true, "3": true})) assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) err = g.AddEdge("2", END) assert.NoError(t, err) err = g.AddEdge("3", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx) assert.NoError(t, err) result, err := r.Invoke(ctx, "start") assert.NoError(t, err) assert.Equal(t, map[string]any{ "1": "start", "2": "start", }, result) streamResult, err := r.Stream(ctx, "start") assert.NoError(t, err) result = map[string]any{} for { chunk, err := streamResult.Recv() if err == io.EOF { break } assert.NoError(t, err) for k, v := range chunk { result[k] = v } } assert.Equal(t, map[string]any{ "1": "start", "2": "start", }, result) } func TestStreamMultiBranch(t *testing.T) { g := NewGraph[string, map[string]any]() emptyLambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }) err := g.AddLambdaNode("1", emptyLambda, WithOutputKey("1")) assert.NoError(t, err) err = g.AddLambdaNode("2", emptyLambda, WithOutputKey("2")) assert.NoError(t, err) err = g.AddLambdaNode("3", emptyLambda, WithOutputKey("3")) assert.NoError(t, err) err = g.AddBranch(START, NewStreamGraphMultiBranch(func(ctx 
context.Context, in *schema.StreamReader[string]) (endNode map[string]bool, err error) { in.Close() return map[string]bool{"1": true, "2": true}, nil }, map[string]bool{"1": true, "2": true, "3": true})) assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) err = g.AddEdge("2", END) assert.NoError(t, err) err = g.AddEdge("3", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx) assert.NoError(t, err) result, err := r.Invoke(ctx, "start") assert.NoError(t, err) assert.Equal(t, map[string]any{ "1": "start", "2": "start", }, result) streamResult, err := r.Stream(ctx, "start") assert.NoError(t, err) result = map[string]any{} for { chunk, err := streamResult.Recv() if err == io.EOF { break } assert.NoError(t, err) for k, v := range chunk { result[k] = v } } assert.Equal(t, map[string]any{ "1": "start", "2": "start", }, result) } ================================================ FILE: compose/chain.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "errors" "fmt" "reflect" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/internal/gmap" "github.com/cloudwego/eino/internal/gslice" ) // NewChain create a chain with input/output type. func NewChain[I, O any](opts ...NewGraphOption) *Chain[I, O] { ch := &Chain[I, O]{ gg: NewGraph[I, O](opts...), } ch.gg.cmp = ComponentOfChain return ch } // Chain is a chain of components. // Chain nodes can be parallel / branch / sequence components. // Chain is designed to be used in a builder pattern (should Compile() before use). // And the interface is `Chain style`, you can use it like: `chain.AppendXX(...).AppendXX(...)` // // Normal usage: // 1. create a chain with input/output type: `chain := NewChain[inputType, outputType]()` // 2. add components to chainable list: // 2.1 add components: `chain.AppendChatTemplate(...).AppendChatModel(...).AppendToolsNode(...)` // 2.2 add parallel or branch node if needed: `chain.AppendParallel()`, `chain.AppendBranch()` // 3. compile: `r, err := c.Compile()` // 4. 
run: // 4.1 `one input & one output` use `r.Invoke(ctx, input)` // 4.2 `one input & multi output chunk` use `r.Stream(ctx, input)` // 4.3 `multi input chunk & one output` use `r.Collect(ctx, inputReader)` // 4.4 `multi input chunk & multi output chunk` use `r.Transform(ctx, inputReader)` // // Using in graph or other chain: // chain1 := NewChain[inputType, outputType]() // graph := NewGraph[](runTypePregel) // graph.AddGraph("key", chain1) // chain is an AnyGraph implementation // // // or in another chain: // chain2 := NewChain[inputType, outputType]() // chain2.AppendGraph(chain1) type Chain[I, O any] struct { err error gg *Graph[I, O] nodeIdx int preNodeKeys []string hasEnd bool } // ErrChainCompiled is returned when attempting to modify a chain after it has been compiled var ErrChainCompiled = errors.New("chain has been compiled, cannot be modified") // implements AnyGraph. func (c *Chain[I, O]) compile(ctx context.Context, option *graphCompileOptions) (*composableRunnable, error) { if err := c.addEndIfNeeded(); err != nil { return nil, err } return c.gg.compile(ctx, option) } // addEndIfNeeded add END edge of the chain/graph. // only run once when compiling. func (c *Chain[I, O]) addEndIfNeeded() error { if c.hasEnd { return nil } if c.err != nil { return c.err } if len(c.preNodeKeys) == 0 { return fmt.Errorf("pre node keys not set, number of nodes in chain= %d", len(c.gg.nodes)) } for _, nodeKey := range c.preNodeKeys { err := c.gg.AddEdge(nodeKey, END) if err != nil { return err } } c.hasEnd = true return nil } func (c *Chain[I, O]) getGenericHelper() *genericHelper { return newGenericHelper[I, O]() } // inputType returns the input type of the chain. // implements AnyGraph. func (c *Chain[I, O]) inputType() reflect.Type { return generic.TypeOf[I]() } // outputType returns the output type of the chain. // implements AnyGraph. 
func (c *Chain[I, O]) outputType() reflect.Type {
	return generic.TypeOf[O]()
}

// component returns the component kind of the underlying graph.
// It implements AnyGraph.
func (c *Chain[I, O]) component() component {
	return c.gg.component()
}

// Compile adds the END edges (once) and compiles the chain into a Runnable
// that can be used directly.
// e.g.
//
//	chain := NewChain[string, string]()
//	r, err := chain.Compile(ctx)
//	if err != nil {}
//
//	r.Invoke(ctx, input) // ping => pong
//	r.Stream(ctx, input) // ping => stream out
//	r.Collect(ctx, inputReader) // stream in => pong
//	r.Transform(ctx, inputReader) // stream in => stream out
func (c *Chain[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) {
	err := c.addEndIfNeeded()
	if err != nil {
		return nil, err
	}
	return c.gg.Compile(ctx, opts...)
}

// AppendChatModel adds a ChatModel node to the chain.
// e.g.
//
//	model, err := openai.NewChatModel(ctx, config)
//	if err != nil {...}
//	chain.AppendChatModel(model)
func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAddNodeOpt) *Chain[I, O] {
	c.addNode(toChatModelNode(node, opts...))
	return c
}

// AppendChatTemplate adds a ChatTemplate node to the chain.
// e.g.
//
//	chatTemplate, err := prompt.FromMessages(schema.FString, &schema.Message{
//		Role:    schema.System,
//		Content: "You are acting as a {role}.",
//	})
//
//	chain.AppendChatTemplate(chatTemplate)
func (c *Chain[I, O]) AppendChatTemplate(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) *Chain[I, O] {
	c.addNode(toChatTemplateNode(node, opts...))
	return c
}

// AppendToolsNode adds a ToolsNode node to the chain.
// e.g.
//
//	toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{
//		Tools: []tools.Tool{...},
//	})
//
//	chain.AppendToolsNode(toolsNode)
func (c *Chain[I, O]) AppendToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) *Chain[I, O] {
	c.addNode(toToolsNode(node, opts...))
	return c
}

// AppendDocumentTransformer adds a DocumentTransformer node to the chain.
// e.g.
//
//	markdownSplitter, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderSplitterConfig{})
//
//	chain.AppendDocumentTransformer(markdownSplitter)
func (c *Chain[I, O]) AppendDocumentTransformer(node document.Transformer, opts ...GraphAddNodeOpt) *Chain[I, O] {
	c.addNode(toDocumentTransformerNode(node, opts...))
	return c
}

// AppendLambda adds a Lambda node to the chain. A Lambda node implements
// custom logic.
// e.g.
//
//	lambdaNode := compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) (string, error) {...})
//	chain.AppendLambda(lambdaNode)
//
// Note:
// create a Lambda with `compose.AnyLambda`, `compose.InvokableLambda`, `compose.StreamableLambda` or `compose.TransformableLambda`.
// For a node with real streaming output, use `compose.StreamableLambda` or `compose.TransformableLambda`.
func (c *Chain[I, O]) AppendLambda(node *Lambda, opts ...GraphAddNodeOpt) *Chain[I, O] {
	c.addNode(toLambdaNode(node, opts...))
	return c
}

// AppendEmbedding adds an Embedding node to the chain.
// e.g.
//
//	embedder, err := openai.NewEmbedder(ctx, config)
//	if err != nil {...}
//	chain.AppendEmbedding(embedder)
func (c *Chain[I, O]) AppendEmbedding(node embedding.Embedder, opts ...GraphAddNodeOpt) *Chain[I, O] {
	c.addNode(toEmbeddingNode(node, opts...))
	return c
}

// AppendRetriever adds a Retriever node to the chain.
// e.g.
// // retriever, err := vectorstore.NewRetriever(ctx, config) // if err != nil {...} // chain.AppendRetriever(retriever) // // or using fornax knowledge as retriever: // // config := fornaxknowledge.Config{...} // retriever, err := fornaxknowledge.NewKnowledgeRetriever(ctx, config) // if err != nil {...} // chain.AppendRetriever(retriever) func (c *Chain[I, O]) AppendRetriever(node retriever.Retriever, opts ...GraphAddNodeOpt) *Chain[I, O] { gNode, options := toRetrieverNode(node, opts...) c.addNode(gNode, options) return c } // AppendLoader adds a Loader node to the chain. // e.g. // // loader, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{}) // if err != nil {...} // chain.AppendLoader(loader) func (c *Chain[I, O]) AppendLoader(node document.Loader, opts ...GraphAddNodeOpt) *Chain[I, O] { gNode, options := toLoaderNode(node, opts...) c.addNode(gNode, options) return c } // AppendIndexer add an Indexer node to the chain. // Indexer is a node that can store documents. // e.g. // // vectorStoreImpl, err := vikingdb.NewVectorStorer(ctx, vikingdbConfig) // in components/vectorstore/vikingdb/vectorstore.go // if err != nil {...} // // config := vectorstore.IndexerConfig{VectorStore: vectorStoreImpl} // indexer, err := vectorstore.NewIndexer(ctx, config) // if err != nil {...} // // chain.AppendIndexer(indexer) func (c *Chain[I, O]) AppendIndexer(node indexer.Indexer, opts ...GraphAddNodeOpt) *Chain[I, O] { gNode, options := toIndexerNode(node, opts...) c.addNode(gNode, options) return c } // AppendBranch add a conditional branch to chain. // Each branch within the ChainBranch can be an AnyGraph. // All branches should either lead to END, or converge to another node within the Chain. // e.g. 
// // cb := compose.NewChainBranch(conditionFunc) // cb.AddChatTemplate("chat_template_key_01", chatTemplate) // cb.AddChatTemplate("chat_template_key_02", chatTemplate2) // chain.AppendBranch(cb) func (c *Chain[I, O]) AppendBranch(b *ChainBranch) *Chain[I, O] { if b == nil { c.reportError(fmt.Errorf("append branch invalid, branch is nil")) return c } if b.err != nil { c.reportError(fmt.Errorf("append branch error: %w", b.err)) return c } if len(b.key2BranchNode) == 0 { c.reportError(fmt.Errorf("append branch invalid, nodeList is empty")) return c } if len(b.key2BranchNode) == 1 { c.reportError(fmt.Errorf("append branch invalid, nodeList length = 1")) return c } var startNode string if len(c.preNodeKeys) == 0 { // branch appended directly to START startNode = START } else if len(c.preNodeKeys) == 1 { startNode = c.preNodeKeys[0] } else { c.reportError(fmt.Errorf("append branch invalid, multiple previous nodes: %v ", c.preNodeKeys)) return c } prefix := c.nextNodeKey() key2NodeKey := make(map[string]string, len(b.key2BranchNode)) for key := range b.key2BranchNode { node := b.key2BranchNode[key] var nodeKey string if node.Second != nil && node.Second.nodeOptions != nil && node.Second.nodeOptions.nodeKey != "" { nodeKey = node.Second.nodeOptions.nodeKey } else { nodeKey = fmt.Sprintf("%s_branch_%s", prefix, key) } if err := c.gg.addNode(nodeKey, node.First, node.Second); err != nil { c.reportError(fmt.Errorf("add branch node[%s] to chain failed: %w", nodeKey, err)) return c } key2NodeKey[key] = nodeKey } gBranch := *b.internalBranch invokeCon := func(ctx context.Context, in any) (endNode []string, err error) { ends, err := b.internalBranch.invoke(ctx, in) if err != nil { return nil, err } nodeKeyEnds := make([]string, 0, len(ends)) for _, end := range ends { if nodeKey, ok := key2NodeKey[end]; !ok { return nil, fmt.Errorf("branch invocation returns unintended end node: %s", end) } else { nodeKeyEnds = append(nodeKeyEnds, nodeKey) } } return nodeKeyEnds, nil } 
gBranch.invoke = invokeCon collectCon := func(ctx context.Context, sr streamReader) ([]string, error) { ends, err := b.internalBranch.collect(ctx, sr) if err != nil { return nil, err } nodeKeyEnds := make([]string, 0, len(ends)) for _, end := range ends { if nodeKey, ok := key2NodeKey[end]; !ok { return nil, fmt.Errorf("branch invocation returns unintended end node: %s", end) } else { nodeKeyEnds = append(nodeKeyEnds, nodeKey) } } return nodeKeyEnds, nil } gBranch.collect = collectCon gBranch.endNodes = gslice.ToMap(gmap.Values(key2NodeKey), func(k string) (string, bool) { return k, true }) if err := c.gg.AddBranch(startNode, &gBranch); err != nil { c.reportError(fmt.Errorf("chain append branch failed: %w", err)) return c } c.preNodeKeys = gmap.Values(key2NodeKey) return c } // AppendParallel add a Parallel structure (multiple concurrent nodes) to the chain. // e.g. // // parallel := compose.NewParallel() // parallel.AddChatModel("openai", model1) // => "openai": *schema.Message{} // parallel.AddChatModel("maas", model2) // => "maas": *schema.Message{} // // chain.AppendParallel(parallel) // => multiple concurrent nodes are added to the Chain // // The next node in the chain is either an END, or a node which accepts a map[string]any, where keys are `openai` `maas` as specified above. 
func (c *Chain[I, O]) AppendParallel(p *Parallel) *Chain[I, O] { if p == nil { c.reportError(fmt.Errorf("append parallel invalid, parallel is nil")) return c } if p.err != nil { c.reportError(fmt.Errorf("append parallel invalid, parallel error: %w", p.err)) return c } if len(p.nodes) <= 1 { c.reportError(fmt.Errorf("append parallel invalid, not enough nodes, count = %d", len(p.nodes))) return c } var startNode string if len(c.preNodeKeys) == 0 { // parallel appended directly to START startNode = START } else if len(c.preNodeKeys) == 1 { startNode = c.preNodeKeys[0] } else { c.reportError(fmt.Errorf("append parallel invalid, multiple previous nodes: %v ", c.preNodeKeys)) return c } prefix := c.nextNodeKey() var nodeKeys []string for i := range p.nodes { node := p.nodes[i] var nodeKey string if node.Second != nil && node.Second.nodeOptions != nil && node.Second.nodeOptions.nodeKey != "" { nodeKey = node.Second.nodeOptions.nodeKey } else { nodeKey = fmt.Sprintf("%s_parallel_%d", prefix, i) } if err := c.gg.addNode(nodeKey, node.First, node.Second); err != nil { c.reportError(fmt.Errorf("add parallel node to chain failed, key=%s, err: %w", nodeKey, err)) return c } if err := c.gg.AddEdge(startNode, nodeKey); err != nil { c.reportError(fmt.Errorf("add parallel edge failed, from=%s, to=%s, err: %w", startNode, nodeKey, err)) return c } nodeKeys = append(nodeKeys, nodeKey) } c.preNodeKeys = nodeKeys return c } // AppendGraph add a AnyGraph node to the chain. // AnyGraph can be a chain or a graph. // e.g. // // graph := compose.NewGraph[string, string]() // chain.AppendGraph(graph) func (c *Chain[I, O]) AppendGraph(node AnyGraph, opts ...GraphAddNodeOpt) *Chain[I, O] { gNode, options := toAnyGraphNode(node, opts...) c.addNode(gNode, options) return c } // AppendPassthrough add a Passthrough node to the chain. // Could be used to connect multiple ChainBranch or Parallel. // e.g. 
// // chain.AppendPassthrough() func (c *Chain[I, O]) AppendPassthrough(opts ...GraphAddNodeOpt) *Chain[I, O] { gNode, options := toPassthroughNode(opts...) c.addNode(gNode, options) return c } // nextIdx. // get the next idx for the chain. // chain key is: node_idx => eg: node_0 => represent the first node of the chain (idx start from 0) // if has parallel: node_idx_parallel_idx => eg: node_0_parallel_1 => represent the first node of the chain, and is a parallel node, and the second node of the parallel // if has branch: node_idx_branch_key => eg: node_1_branch_customkey => represent the second node of the chain, and is a branch node, and the 'customkey' is the key of the branch func (c *Chain[I, O]) nextNodeKey() string { idx := c.nodeIdx c.nodeIdx++ return fmt.Sprintf("node_%d", idx) } // reportError. // save the first error in the chain. func (c *Chain[I, O]) reportError(err error) { if c.err == nil { c.err = err } } // addNode. // add a node to the chain. func (c *Chain[I, O]) addNode(node *graphNode, options *graphAddNodeOpts) { if c.err != nil { return } if c.gg.compiled { c.reportError(ErrChainCompiled) return } if node == nil { c.reportError(fmt.Errorf("chain add node invalid, node is nil")) return } nodeKey := options.nodeOptions.nodeKey defaultNodeKey := c.nextNodeKey() if nodeKey == "" { nodeKey = defaultNodeKey } err := c.gg.addNode(nodeKey, node, options) if err != nil { c.reportError(err) return } if len(c.preNodeKeys) == 0 { c.preNodeKeys = append(c.preNodeKeys, START) } for _, preNodeKey := range c.preNodeKeys { e := c.gg.AddEdge(preNodeKey, nodeKey) if e != nil { c.reportError(e) return } } c.preNodeKeys = []string{nodeKey} } ================================================ FILE: compose/chain_branch.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package compose

import (
	"context"
	"fmt"

	"github.com/cloudwego/eino/components/document"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/indexer"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/internal/generic"
	"github.com/cloudwego/eino/schema"
)

// nodeOptionsPair couples a graph node (First) with the options it was added
// with (Second).
type nodeOptionsPair generic.Pair[*graphNode, *graphAddNodeOpts]

// ChainBranch represents a conditional branch in a chain of operations.
// It allows for dynamic routing of execution based on a condition.
// All branches within ChainBranch are expected to either end the Chain, or converge to another node in the Chain.
type ChainBranch struct {
	// internalBranch wraps the user's condition; the condition yields branch keys
	// (keys of key2BranchNode), not graph node keys.
	internalBranch *GraphBranch
	// key2BranchNode maps each branch key to the node (and its options) run on that path.
	key2BranchNode map[string]nodeOptionsPair
	// err holds the first error hit while building the branch, e.g. a duplicate key.
	err error
}

// NewChainMultiBranch creates a chain branch where a condition selects
// multiple end nodes to route execution.
func NewChainMultiBranch[T any](cond GraphMultiBranchCondition[T]) *ChainBranch { invokeCond := func(ctx context.Context, in T, opts ...any) (endNodes []string, err error) { ends, err := cond(ctx, in) if err != nil { return nil, err } endNodes = make([]string, 0, len(ends)) for end := range ends { endNodes = append(endNodes, end) } return endNodes, nil } return &ChainBranch{ key2BranchNode: make(map[string]nodeOptionsPair), internalBranch: newGraphBranch(newRunnablePacker(invokeCond, nil, nil, nil, false), nil), } } // NewStreamChainMultiBranch creates a chain branch that selects multiple end // nodes based on a condition evaluated on the input stream. func NewStreamChainMultiBranch[T any](cond StreamGraphMultiBranchCondition[T]) *ChainBranch { collectCon := func(ctx context.Context, in *schema.StreamReader[T], opts ...any) (endNodes []string, err error) { ends, err := cond(ctx, in) if err != nil { return nil, err } endNodes = make([]string, 0, len(ends)) for end := range ends { endNodes = append(endNodes, end) } return endNodes, nil } return &ChainBranch{ key2BranchNode: make(map[string]nodeOptionsPair), internalBranch: newGraphBranch(newRunnablePacker(nil, nil, collectCon, nil, false), nil), } } // NewChainBranch creates a new ChainBranch instance based on a given condition. // It takes a generic type T and a GraphBranchCondition function for that type. // The returned ChainBranch will have an empty key2BranchNode map and a condition function // that wraps the provided cond to handle type assertions and error checking. // eg. 
// // condition := func(ctx context.Context, in string, opts ...any) (endNode string, err error) { // // logic to determine the next node // return "some_next_node_key", nil // } // // cb := NewChainBranch[string](condition) // cb.AddPassthrough("next_node_key_01", xxx) // node in branch, represent one path of branch // cb.AddPassthrough("next_node_key_02", xxx) // node in branch func NewChainBranch[T any](cond GraphBranchCondition[T]) *ChainBranch { return NewChainMultiBranch(func(ctx context.Context, in T) (endNode map[string]bool, err error) { ret, err := cond(ctx, in) if err != nil { return nil, err } return map[string]bool{ret: true}, nil }) } // NewStreamChainBranch creates a new ChainBranch instance based on a given stream condition. // It takes a generic type T and a StreamGraphBranchCondition function for that type. // The returned ChainBranch will have an empty key2BranchNode map and a condition function // that wraps the provided cond to handle type assertions and error checking. // eg. // // condition := func(ctx context.Context, in *schema.StreamReader[string], opts ...any) (endNode string, err error) { // // logic to determine the next node, you can read the stream and make a decision. // // to save time, usually read the first chunk of stream, then make a decision which path to go. // return "some_next_node_key", nil // } // // cb := NewStreamChainBranch[string](condition) func NewStreamChainBranch[T any](cond StreamGraphBranchCondition[T]) *ChainBranch { return NewStreamChainMultiBranch(func(ctx context.Context, in *schema.StreamReader[T]) (endNodes map[string]bool, err error) { ret, err := cond(ctx, in) if err != nil { return nil, err } return map[string]bool{ret: true}, nil }) } // AddChatModel adds a ChatModel node to the branch. // eg. 
// // chatModel01, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ // Model: "gpt-4o", // }) // chatModel02, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ // Model: "gpt-4o-mini", // }) // cb.AddChatModel("chat_model_key_01", chatModel01) // cb.AddChatModel("chat_model_key_02", chatModel02) func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toChatModelNode(node, opts...) return cb.addNode(key, gNode, options) } // AddChatTemplate adds a ChatTemplate node to the branch. // eg. // // chatTemplate, err := prompt.FromMessages(schema.FString, &schema.Message{ // Role: schema.System, // Content: "You are acting as a {role}.", // }) // // cb.AddChatTemplate("chat_template_key_01", chatTemplate) // // chatTemplate2, err := prompt.FromMessages(schema.FString, &schema.Message{ // Role: schema.System, // Content: "You are acting as a {role}, you are not allowed to chat in other topics.", // }) // // cb.AddChatTemplate("chat_template_key_02", chatTemplate2) func (cb *ChainBranch) AddChatTemplate(key string, node prompt.ChatTemplate, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toChatTemplateNode(node, opts...) return cb.addNode(key, gNode, options) } // AddToolsNode adds a ToolsNode to the branch. // eg. // // toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ // Tools: []tools.Tool{...}, // }) // // cb.AddToolsNode("tools_node_key", toolsNode) func (cb *ChainBranch) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toToolsNode(node, opts...) return cb.addNode(key, gNode, options) } // AddLambda adds a Lambda node to the branch. // eg. 
// // lambdaFunc := func(ctx context.Context, in string, opts ...any) (out string, err error) { // // logic to process the input // return "processed_output", nil // } // // cb.AddLambda("lambda_node_key", compose.InvokeLambda(lambdaFunc)) func (cb *ChainBranch) AddLambda(key string, node *Lambda, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toLambdaNode(node, opts...) return cb.addNode(key, gNode, options) } // AddEmbedding adds an Embedding node to the branch. // eg. // // embeddingNode, err := openai.NewEmbedder(ctx, &openai.EmbeddingConfig{ // Model: "text-embedding-3-small", // }) // // cb.AddEmbedding("embedding_node_key", embeddingNode) func (cb *ChainBranch) AddEmbedding(key string, node embedding.Embedder, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toEmbeddingNode(node, opts...) return cb.addNode(key, gNode, options) } // AddRetriever adds a Retriever node to the branch. // eg. // // retriever, err := volc_vikingdb.NewRetriever(ctx, &volc_vikingdb.RetrieverConfig{ // Collection: "my_collection", // }) // // cb.AddRetriever("retriever_node_key", retriever) func (cb *ChainBranch) AddRetriever(key string, node retriever.Retriever, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toRetrieverNode(node, opts...) return cb.addNode(key, gNode, options) } // AddLoader adds a Loader node to the branch. // eg. // // pdfParser, err := pdf.NewPDFParser() // loader, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{ // Parser: pdfParser, // }) // // cb.AddLoader("loader_node_key", loader) func (cb *ChainBranch) AddLoader(key string, node document.Loader, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toLoaderNode(node, opts...) return cb.addNode(key, gNode, options) } // AddIndexer adds an Indexer node to the branch. // eg. 
// // indexer, err := volc_vikingdb.NewIndexer(ctx, &volc_vikingdb.IndexerConfig{ // Collection: "my_collection", // }) // // cb.AddIndexer("indexer_node_key", indexer) func (cb *ChainBranch) AddIndexer(key string, node indexer.Indexer, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toIndexerNode(node, opts...) return cb.addNode(key, gNode, options) } // AddDocumentTransformer adds an Document Transformer node to the branch. // eg. // // markdownSplitter, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderSplitterConfig{}) // // cb.AddDocumentTransformer("document_transformer_node_key", markdownSplitter) func (cb *ChainBranch) AddDocumentTransformer(key string, node document.Transformer, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toDocumentTransformerNode(node, opts...) return cb.addNode(key, gNode, options) } // AddGraph adds a generic Graph node to the branch. // eg. // // graph, err := compose.NewGraph[string, string]() // // cb.AddGraph("graph_node_key", graph) func (cb *ChainBranch) AddGraph(key string, node AnyGraph, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toAnyGraphNode(node, opts...) return cb.addNode(key, gNode, options) } // AddPassthrough adds a Passthrough node to the branch. // eg. // // cb.AddPassthrough("passthrough_node_key") func (cb *ChainBranch) AddPassthrough(key string, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toPassthroughNode(opts...) 
return cb.addNode(key, gNode, options) } func (cb *ChainBranch) addNode(key string, node *graphNode, options *graphAddNodeOpts) *ChainBranch { if cb.err != nil { return cb } if cb.key2BranchNode == nil { cb.key2BranchNode = make(map[string]nodeOptionsPair) } _, ok := cb.key2BranchNode[key] if ok { cb.err = fmt.Errorf("chain branch add node, duplicate branch node key= %s", key) return cb } cb.key2BranchNode[key] = nodeOptionsPair{node, options} return cb } ================================================ FILE: compose/chain_branch_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "errors" "fmt" "io" "strconv" "strings" "testing" "unicode/utf8" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/schema" ) func TestChainBranch(t *testing.T) { cond := func(ctx context.Context, input string) (key string, err error) { switch input { case "one": return "one_key", nil case "two": return "two_key", nil case "three": return "three_key", nil default: return "", fmt.Errorf("invalid input= %s", input) } } t.Run("nested chain", func(t *testing.T) { inner := NewChain[string, string]() inner.AppendBranch(NewChainBranch(cond). AddLambda("one_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { return in + in, nil })). 
AddLambda("two_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { return in + in + in, nil }))) inner.AppendParallel(NewParallel(). AddLambda("one_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { return in + in, nil })). AddLambda("two_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { return in + in + in, nil }))) outer := NewChain[string, string]() outer.AppendGraph(inner) _, err := outer.Compile(context.Background()) assert.Error(t, err) }) t.Run("bad param", func(t *testing.T) { c := NewChain[string, string]() c.AppendBranch(nil) assert.NotNil(t, c.err) c = NewChain[string, string]() c.AppendBranch(NewChainBranch[string](nil)) assert.NotNil(t, c.err) c = NewChain[string, string]() c.AppendBranch(NewChainBranch(cond).AddChatTemplate("template", prompt.FromMessages(schema.FString, schema.SystemMessage("hello")))) assert.NotNil(t, c.err) c = NewChain[string, string]() c.AppendBranch(NewChainBranch(cond).AddChatTemplate("1", prompt.FromMessages(schema.FString)).AddChatTemplate("1", prompt.FromMessages(schema.FString))) assert.NotNil(t, c.err) }) t.Run("different Node types in branch", func(t *testing.T) { c := NewChain[string, string]() c.AppendBranch(NewChainBranch(cond). AddChatTemplate("t", prompt.FromMessages(schema.FString)). AddGraph("c", NewChain[string, string]())) assert.NotNil(t, c.err) }) t.Run("type mismatch", func(t *testing.T) { c := NewChain[int, string]() c.AppendBranch(NewChainBranch(cond). AddLambda("one_key", InvokableLambda(func(ctx context.Context, in int) (output string, err error) { return strconv.Itoa(in), nil })). AddLambda("two_key", InvokableLambda(func(ctx context.Context, in int) (output string, err error) { return strconv.Itoa(in), nil }))) _, err := c.Compile(context.Background()) assert.NotNil(t, err) }) t.Run("invoke", func(t *testing.T) { c := NewChain[string, string]() c.AppendBranch(NewChainBranch(cond). 
AddLambda("one_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { return in + in, nil })). AddLambda("two_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { return in + in + in, nil }))) c.AppendLambda(InvokableLambda(func(ctx context.Context, in string) (output string, err error) { return in + in, nil })) assert.Nil(t, c.err) compiledChain, err := c.Compile(context.Background()) assert.Nil(t, err) out, err := compiledChain.Invoke(context.Background(), "two") assert.Nil(t, err) assert.Equal(t, "twotwotwotwotwotwo", out) _, err = compiledChain.Invoke(context.Background(), "three") assert.NotNil(t, err) _, err = compiledChain.Invoke(context.Background(), "four") assert.NotNil(t, err) }) t.Run("fake stream", func(t *testing.T) { c := NewChain[string, string]() c.AppendLambda(StreamableLambda(func(ctx context.Context, in string) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](utf8.RuneCountInString(in)) go func() { for _, field := range strings.Fields(in) { sw.Send(field, nil) } sw.Close() }() return sr, nil })) c.AppendBranch(NewChainBranch[string](cond).AddLambda("one_key", CollectableLambda(func(ctx context.Context, in *schema.StreamReader[string]) (output string, err error) { defer in.Close() for { v, err := in.Recv() if errors.Is(err, io.EOF) { break } if err != nil { return "", err } output += v } return output + output, nil })). 
AddLambda("two_key", CollectableLambda(func(ctx context.Context, in *schema.StreamReader[string]) (output string, err error) { defer in.Close() for { v, err := in.Recv() if errors.Is(err, io.EOF) { break } if err != nil { return "", err } output += v } return output + output + output, nil }))) assert.Nil(t, c.err) compiledChain, err := c.Compile(context.Background()) assert.Nil(t, err) out, err := compiledChain.Invoke(context.Background(), "one") assert.Nil(t, err) assert.Equal(t, "oneone", out) }) t.Run("real stream", func(t *testing.T) { streamCon := func(ctx context.Context, sr *schema.StreamReader[string]) (key string, err error) { msg, err := sr.Recv() if err != nil { return "", err } defer sr.Close() switch msg { case "one": return "one_key", nil case "two": return "two_key", nil case "three": return "three_key", nil default: return "", fmt.Errorf("invalid input= %s", msg) } } c := NewChain[string, string]() c.AppendLambda(StreamableLambda(func(ctx context.Context, in string) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](utf8.RuneCountInString(in)) go func() { for _, field := range strings.Fields(in) { sw.Send(field, nil) } sw.Close() }() return sr, nil })) c.AppendBranch(NewStreamChainBranch(streamCon).AddLambda("one_key", CollectableLambda(func(ctx context.Context, in *schema.StreamReader[string]) (output string, err error) { defer in.Close() for { v, err := in.Recv() if errors.Is(err, io.EOF) { break } if err != nil { return "", err } output += v } return output + output, nil })). 
AddLambda("two_key", CollectableLambda(func(ctx context.Context, in *schema.StreamReader[string]) (output string, err error) { defer in.Close() for { v, err := in.Recv() if errors.Is(err, io.EOF) { break } if err != nil { return "", err } output += v } return output + output + output, nil }))) assert.Nil(t, c.err) compiledChain, err := c.Compile(context.Background()) assert.Nil(t, err) out, err := compiledChain.Stream(context.Background(), "one size fit all") assert.Nil(t, err) concat, err := concatStreamReader(out) assert.Nil(t, err) assert.Equal(t, "onesizefitallonesizefitall", concat) }) } func TestChainMultiBranch(t *testing.T) { emptyLambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }) ctx := context.Background() r, err := NewChain[string, map[string]any](). AppendBranch(NewChainMultiBranch(func(ctx context.Context, in string) (endNode map[string]bool, err error) { return map[string]bool{"1": true, "2": true}, nil }).AddLambda("1", emptyLambda, WithOutputKey("1")).AddLambda("2", emptyLambda, WithOutputKey("2")).AddLambda("3", emptyLambda, WithOutputKey("3"))). Compile(ctx) assert.Nil(t, err) result, err := r.Invoke(ctx, "start") assert.NoError(t, err) assert.Equal(t, map[string]any{ "1": "start", "2": "start", }, result) streamResult, err := r.Stream(ctx, "start") assert.NoError(t, err) result = map[string]any{} for { chunk, err := streamResult.Recv() if err == io.EOF { break } assert.NoError(t, err) for k, v := range chunk { result[k] = v } } assert.Equal(t, map[string]any{ "1": "start", "2": "start", }, result) } func TestStreamChainMultiBranch(t *testing.T) { emptyLambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }) ctx := context.Background() r, err := NewChain[string, map[string]any](). 
AppendBranch(NewStreamChainMultiBranch(func(ctx context.Context, in *schema.StreamReader[string]) (endNode map[string]bool, err error) { return map[string]bool{"1": true, "2": true}, nil }).AddLambda("1", emptyLambda, WithOutputKey("1")).AddLambda("2", emptyLambda, WithOutputKey("2")).AddLambda("3", emptyLambda, WithOutputKey("3"))). Compile(ctx) assert.Nil(t, err) result, err := r.Invoke(ctx, "start") assert.NoError(t, err) assert.Equal(t, map[string]any{ "1": "start", "2": "start", }, result) streamResult, err := r.Stream(ctx, "start") assert.NoError(t, err) result = map[string]any{} for { chunk, err := streamResult.Recv() if err == io.EOF { break } assert.NoError(t, err) for k, v := range chunk { result[k] = v } } assert.Equal(t, map[string]any{ "1": "start", "2": "start", }, result) } ================================================ FILE: compose/chain_parallel.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "fmt" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" ) // NewParallel creates a new parallel type. // it is useful when you want to run multiple nodes in parallel in a chain. 
func NewParallel() *Parallel { return &Parallel{ outputKeys: make(map[string]bool), } } // Parallel run multiple nodes in parallel // // use `NewParallel()` to create a new parallel type // Example: // // parallel := NewParallel() // parallel.AddChatModel("output_key01", chat01) // parallel.AddChatModel("output_key01", chat02) // // chain := NewChain[any,any]() // chain.AppendParallel(parallel) type Parallel struct { nodes []nodeOptionsPair outputKeys map[string]bool err error } // AddChatModel adds a chat model to the parallel. // eg. // // chatModel01, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ // Model: "gpt-4o", // }) // // chatModel02, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ // Model: "gpt-4o", // }) // // p.AddChatModel("output_key01", chatModel01) // p.AddChatModel("output_key02", chatModel02) func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toChatModelNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddChatTemplate adds a chat template to the parallel. // eg. // // chatTemplate01, err := prompt.FromMessages(schema.FString, &schema.Message{ // Role: schema.System, // Content: "You are acting as a {role}.", // }) // // p.AddChatTemplate("output_key01", chatTemplate01) func (p *Parallel) AddChatTemplate(outputKey string, node prompt.ChatTemplate, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toChatTemplateNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddToolsNode adds a tools node to the parallel. // eg. 
// // toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ // Tools: []tool.BaseTool{...}, // }) // // p.AddToolsNode("output_key01", toolsNode) func (p *Parallel) AddToolsNode(outputKey string, node *ToolsNode, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toToolsNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddLambda adds a lambda node to the parallel. // eg. // // lambdaFunc := func(ctx context.Context, input *schema.Message) ([]*schema.Message, error) { // return []*schema.Message{input}, nil // } // // p.AddLambda("output_key01", compose.InvokeLambda(lambdaFunc)) func (p *Parallel) AddLambda(outputKey string, node *Lambda, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toLambdaNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddEmbedding adds an embedding node to the parallel. // eg. // // embeddingNode, err := openai.NewEmbedder(ctx, &openai.EmbeddingConfig{ // Model: "text-embedding-3-small", // }) // // p.AddEmbedding("output_key01", embeddingNode) func (p *Parallel) AddEmbedding(outputKey string, node embedding.Embedder, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toEmbeddingNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddRetriever adds a retriever node to the parallel. // eg. // // retriever, err := vikingdb.NewRetriever(ctx, &vikingdb.RetrieverConfig{}) // // p.AddRetriever("output_key01", retriever) func (p *Parallel) AddRetriever(outputKey string, node retriever.Retriever, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toRetrieverNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddLoader adds a loader node to the parallel. // eg. 
// // loader, err := file.NewLoader(ctx, &file.LoaderConfig{}) // // p.AddLoader("output_key01", loader) func (p *Parallel) AddLoader(outputKey string, node document.Loader, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toLoaderNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddIndexer adds an indexer node to the parallel. // eg. // // indexer, err := volc_vikingdb.NewIndexer(ctx, &volc_vikingdb.IndexerConfig{ // Collection: "my_collection", // }) // // p.AddIndexer("output_key01", indexer) func (p *Parallel) AddIndexer(outputKey string, node indexer.Indexer, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toIndexerNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddDocumentTransformer adds an Document Transformer node to the parallel. // eg. // // markdownSplitter, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderSplitterConfig{}) // // p.AddDocumentTransformer("output_key01", markdownSplitter) func (p *Parallel) AddDocumentTransformer(outputKey string, node document.Transformer, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toDocumentTransformerNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddGraph adds a graph node to the parallel. // It is useful when you want to use a graph or a chain as a node in the parallel. // eg. // // graph, err := compose.NewChain[any,any]() // // p.AddGraph("output_key01", graph) func (p *Parallel) AddGraph(outputKey string, node AnyGraph, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toAnyGraphNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } // AddPassthrough adds a passthrough node to the parallel. // eg. 
// // p.AddPassthrough("output_key01") func (p *Parallel) AddPassthrough(outputKey string, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toPassthroughNode(append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } func (p *Parallel) addNode(outputKey string, node *graphNode, options *graphAddNodeOpts) *Parallel { if p.err != nil { return p } if node == nil { p.err = fmt.Errorf("chain parallel add node invalid, node is nil") return p } if p.outputKeys == nil { p.outputKeys = make(map[string]bool) } if _, ok := p.outputKeys[outputKey]; ok { p.err = fmt.Errorf("parallel add node err, duplicate output key= %s", outputKey) return p } if node.nodeInfo == nil { p.err = fmt.Errorf("chain parallel add node invalid, nodeInfo is nil") return p } node.nodeInfo.outputKey = outputKey p.nodes = append(p.nodes, nodeOptionsPair{node, options}) p.outputKeys[outputKey] = true return p } ================================================ FILE: compose/chain_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */

package compose

import (
	"context"
	"fmt"
	"math/rand"
	"testing"

	"github.com/stretchr/testify/assert"
	"go.uber.org/mock/gomock"

	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/internal/mock/components/document"
	"github.com/cloudwego/eino/internal/mock/components/embedding"
	"github.com/cloudwego/eino/internal/mock/components/indexer"
	"github.com/cloudwego/eino/internal/mock/components/model"
	"github.com/cloudwego/eino/internal/mock/components/retriever"
	"github.com/cloudwego/eino/schema"
)

// TestChain builds the chain
// lambda -> branch(b1|b2) -> passthrough -> parallel(role, input) -> role-play sub-chain -> lambda,
// then checks that it compiles and invokes without error.
// NOTE(review): the branch taken is chosen with rand.Intn, so b1 vs b2 varies per run.
func TestChain(t *testing.T) {
	cm := &mockIntentChatModel{}

	// Build the branch: the condition picks "b1" or "b2" at random.
	branchCond := func(ctx context.Context, input map[string]any) (string, error) {
		if rand.Intn(2) == 1 {
			return "b1", nil
		}
		return "b2", nil
	}

	b1 := InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
		t.Log("hello in branch lambda 01")
		kvs["role"] = "cat"
		return kvs, nil
	})

	b2 := InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
		t.Log("hello in branch lambda 02")
		kvs["role"] = "dog"
		return kvs, nil
	})

	// Parallel nodes: produce the "role" and "input" template variables.
	parallel := NewParallel()
	parallel.
		AddLambda("role", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
			// The role could be changed based on the input kvs, e.g. to dentist/doctor...
			role := kvs["role"]
			if role.(string) == "" {
				role = "bird"
			}
			return role.(string), nil
		})).
		AddLambda("input", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
			// The literal below asks (in Chinese): "What does your call sound like?"
			return "你的叫声是怎样的?", nil
		}))

	// Sequential nodes: a chat template followed by the chat model.
	rolePlayChain := NewChain[map[string]any, *schema.Message]()
	rolePlayChain.
		AppendChatTemplate(prompt.FromMessages(schema.FString, schema.SystemMessage(`You are a {role}.`), schema.UserMessage(`{input}`))).
		AppendChatModel(cm)

	// Build the chain.
	chain := NewChain[map[string]any, string]()
	chain.
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			// do some logic to prepare kv as variables for next Node
			// just pass through
			t.Log("in view lambda: ", kvs)
			return kvs, nil
		})).
		AppendBranch(NewChainBranch[map[string]any](branchCond).AddLambda("b1", b1).AddLambda("b2", b2)).
		AppendPassthrough().
		AppendParallel(parallel).
		AppendGraph(rolePlayChain).
		AppendLambda(InvokableLambda(func(ctx context.Context, m *schema.Message) (string, error) {
			// do some logic to check the output or something
			t.Log("in view of messages: ", m.Content)
			return m.Content, nil
		}))

	r, err := chain.Compile(context.Background())
	assert.Nil(t, err)

	out, err := r.Invoke(context.Background(), map[string]any{})
	assert.Nil(t, err)
	t.Log(err)

	t.Log("out is : ", out)
}

// TestChainWithException composes branches, nested chains and nested parallels
// into one chain. Despite its name, it expects both Compile and Invoke to
// succeed (err is asserted nil).
func TestChainWithException(t *testing.T) {
	chain := NewChain[map[string]any, string]()
	chain.
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			// do some logic to prepare kv as variables for next Node
			// just pass through
			t.Log("in view lambda: ", kvs)
			return kvs, nil
		})).
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in view lambda 02: ", kvs)
			return kvs, nil
		}), WithNodeKey("xlam"))

	// items with parallels
	parallel := NewParallel()
	parallel.
		AddLambda("hello", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
			t.Log("in parallel item 01")
			return "world", nil
		})).
		AddLambda("world", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
			t.Log("in parallel item 02")
			return "hello", nil
		}))

	// sequence items
	nchain := NewChain[map[string]any, map[string]any]()
	nchain.
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in sequence item 01")
			return kvs, nil
		})).
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in sequence item 02")
			return kvs, nil
		}))

	// The branch condition picks "b1" or "b2" at random.
	branchCond := func(ctx context.Context, input map[string]any) (string, error) {
		if rand.Intn(2) == 1 {
			return "b1", nil
		}
		return "b2", nil
	}

	b1 := InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
		t.Log("hello in branch lambda 01")
		kvs["role"] = "cat"
		return kvs, nil
	})

	b2 := InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
		return kvs, nil
	})

	// sequence with branch
	chain.AppendBranch(NewChainBranch[map[string]any](branchCond).AddLambda("b1", b1).AddLambda("b2", b2))

	// parallel with sequence
	parallel.AddGraph("test_sequence", nchain)

	// parallel with parallel
	npara := NewParallel().
		AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		})).
		AddLambda("test_parallel2", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))

	// parallel with graph
	ngraph := NewChain[map[string]any, map[string]any]().
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in graph item 01")
			return kvs, nil
		})).
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in graph item 02")
			return kvs, nil
		}))

	nc := NewChain[map[string]any, map[string]any]()
	nc.AppendGraph(ngraph)
	parallel.AddGraph("test_graph", nc)

	chain.AppendPassthrough()

	// sequence with parallel
	chain.AppendParallel(npara)

	// Build the rest of the chain.
	chain.
		AppendGraph(nchain).
		AppendParallel(parallel).
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
			t.Log("in last view lambda: ", kvs)
			return "hello last", nil
		}))

	ctx := context.Background()
	r, err := chain.Compile(ctx)
	assert.Nil(t, err)

	out, err := r.Invoke(ctx, map[string]any{"test": "test"})
	assert.Nil(t, err)
	t.Log("out is : ", out)
}

// TestEmptyList checks that Compile fails for three cases:
//   - an empty chain;
//   - a chain holding an empty parallel;
//   - a chain that contains an empty parallel and an empty sub-chain.
func TestEmptyList(t *testing.T) {
	ctx := context.Background()
	// no nodes in chain
	chain := NewChain[map[string]any, map[string]any]()
	_, err := chain.Compile(ctx)
	assert.Error(t, err)

	// no nodes in parallel
	parallel := NewParallel()
	chain = NewChain[map[string]any, map[string]any]()
	chain.AppendParallel(parallel)
	_, err = chain.Compile(ctx)
	assert.Error(t, err)

	// no nodes in sequence
	emptyChain := NewChain[map[string]any, map[string]any]()
	chain = NewChain[map[string]any, map[string]any]()
	chain.
		AppendParallel(parallel).
		AppendGraph(emptyChain).
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
	_, err = chain.Compile(ctx)
	assert.Error(t, err)
}

// TestChainList checks nested composition:
//   - a parallel whose second member is a sub-chain;
//   - that sub-chain itself embeds another sub-chain.
func TestChainList(t *testing.T) {
	chain := NewChain[map[string]any, map[string]any]()
	chain.
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in view lambda: ", kvs)
			return kvs, nil
		}))

	// parallel
	parallel := NewParallel()
	parallel.
		AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in parallel item 01")
			return kvs, nil
		}))

	// seq in parallel
	nchain := NewChain[map[string]any, map[string]any]()
	nchain.
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in sequence in parallel item 01")
			return kvs, nil
		})).
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in sequence in parallel item 02")
			return kvs, nil
		}))

	// seq in seq
	nchainInChain := NewChain[map[string]any, map[string]any]()
	nchainInChain.
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in sequence in sequence item 01")
			return kvs, nil
		})).
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in sequence in sequence item 02")
			return kvs, nil
		}))

	nchain.AppendGraph(nchainInChain)
	parallel.AddGraph("test_seq_in_parallel", nchain)

	chain.AppendParallel(parallel)

	r, err := chain.Compile(context.Background())
	assert.Nil(t, err)

	out, err := r.Invoke(context.Background(), map[string]any{"test": "test"})
	assert.Nil(t, err)
	t.Log("out is : ", out)
}

// TestChainSingleNode checks a parallel whose members are a lambda and a
// sub-chain containing a single node.
func TestChainSingleNode(t *testing.T) {
	chain := NewChain[map[string]any, map[string]any]()
	chain.
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in view lambda: ", kvs)
			return kvs, nil
		}))

	// single Node in chain (prepare for parallel)
	singleNodeChain := NewChain[map[string]any, map[string]any]()
	singleNodeChain.
		AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in sequence item 01")
			return kvs, nil
		}))

	// add parallel
	parallel := NewParallel()
	parallel.
		AddLambda("test_parallel1_lambda", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			t.Log("in parallel item 01")
			return kvs, nil
		}))
	parallel.AddGraph("test_parallel2_chain", singleNodeChain)

	ctx := context.Background()
	chain.AppendParallel(parallel)
	r, err := chain.Compile(ctx)
	assert.Nil(t, err)

	out, err := r.Invoke(ctx, map[string]any{"test": "test"})
	assert.Nil(t, err)
	t.Log("out is : ", out)
}

// TestParallelModels reuses one chat sub-chain (template -> model -> lambda) as
// three members of the same parallel and checks that the chain compiles and runs.
func TestParallelModels(t *testing.T) {
	cm := &mockIntentChatModel{}

	chain := NewChain[map[string]any, map[string]any]()

	chatSuite := NewChain[map[string]any, string]()
	chatSuite.
		AppendChatTemplate(prompt.FromMessages(schema.FString, schema.SystemMessage(`You are a {role}.`), schema.UserMessage(`{input}`))).
		AppendChatModel(cm).
		AppendLambda(InvokableLambda(func(ctx context.Context, msg *schema.Message) (string, error) {
			t.Log("in parallel item 01")
			return msg.Content, nil
		}))

	parallel := NewParallel()
	parallel.
		AddGraph("time001", chatSuite).
		AddGraph("time002", chatSuite).
		AddGraph("time003", chatSuite)

	chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
		t.Log("in view lambda: ", kvs)
		return kvs, nil
	}))
	chain.AppendParallel(parallel)

	ctx := context.Background()
	r, err := chain.Compile(ctx)
	assert.Nil(t, err)

	// The "input" value asks (in Chinese): "How do you call?"
	out, err := r.Invoke(ctx, map[string]any{"role": "cat", "input": "你怎么叫的?"})
	assert.Nil(t, err)
	t.Log("out is : ", out)
}

// TestChainMultiNodes checks that each component kind can be appended to a
// chain and compiled. It also covers parallel validation and input/output type
// mismatches between a chain and its nodes.
func TestChainMultiNodes(t *testing.T) {
	ctx := context.Background()
	t.Run("test embedding Node", func(t *testing.T) {
		chain := NewChain[[]string, [][]float64]()
		mockCtrl := gomock.NewController(t)
		eb := embedding.NewMockEmbedder(mockCtrl)
		chain.AppendEmbedding(eb)
		r, err := chain.Compile(ctx)
		assert.NoError(t, err)
		assert.NotNil(t, r)
	})

	t.Run("test retriever Node", func(t *testing.T) {
		chain := NewChain[string, []*schema.Document]()
		chain.AppendRetriever(retriever.NewMockRetriever(gomock.NewController(t)))
		r, err := chain.Compile(ctx)
		assert.NoError(t, err)
		assert.NotNil(t, r)
	})

	t.Run("test chat model", func(t *testing.T) {
		chain := NewChain[[]*schema.Message, *schema.Message]()
		cm := &mockIntentChatModel{}
		chain.AppendChatModel(cm)
		r, err := chain.Compile(ctx)
		assert.NoError(t, err)
		assert.NotNil(t, r)
	})

	t.Run("test chat template", func(t *testing.T) {
		chain := NewChain[map[string]any, []*schema.Message]()
		chatTemplate := prompt.FromMessages(schema.FString)
		chain.AppendChatTemplate(chatTemplate)
		r, err := chain.Compile(ctx)
		assert.NoError(t, err)
		assert.NotNil(t, r)
	})

	t.Run("test lambda", func(t *testing.T) {
		chain := NewChain[map[string]any, map[string]any]()
		chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		r, err := chain.Compile(ctx)
		assert.NoError(t, err)
		assert.NotNil(t, r)
	})

	t.Run("test indexer", func(t *testing.T) {
		chain := NewChain[[]*schema.Document, []string]()
		chain.AppendIndexer(indexer.NewMockIndexer(gomock.NewController(t)))
		r, err := chain.Compile(ctx)
		assert.NoError(t, err)
		assert.NotNil(t, r)
	})

	t.Run("test parallel", func(t *testing.T) {
		// A parallel with only one node is expected to fail to compile.
		chain := NewChain[map[string]any, map[string]any]()
		parallel := NewParallel()
		parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		chain.AppendParallel(parallel)
		_, err := chain.Compile(ctx)
		assert.Error(t, err)

		// A duplicate output key inside one parallel is expected to fail.
		chain = NewChain[map[string]any, map[string]any]()
		parallel = NewParallel()
		parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		chain.AppendParallel(parallel)
		_, err = chain.Compile(ctx)
		assert.Error(t, err)

		// Two nodes with distinct output keys compile successfully.
		chain = NewChain[map[string]any, map[string]any]()
		parallel = NewParallel()
		parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		parallel.AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		chain.AppendParallel(parallel)
		_, err = chain.Compile(ctx)
		assert.NoError(t, err)

		// Two parallels appended back to back are expected to fail to compile.
		// NOTE(review): the exact rejection reason is not visible here — confirm
		// which chain validation triggers it.
		chain = NewChain[map[string]any, map[string]any]()
		parallel = NewParallel()
		parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		parallel.AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		chain.AppendParallel(parallel)
		parallel1 := NewParallel()
		parallel1.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		parallel1.AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		chain.AppendParallel(parallel1)
		_, err = chain.Compile(ctx)
		assert.Error(t, err)
	})

	// NOTE(review): this subtest only checks ToolsNode construction; the chain
	// is never compiled.
	t.Run("test tools Node", func(t *testing.T) {
		ctx := context.Background()
		chain := NewChain[map[string]any, map[string]any]()
		toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{})
		assert.NoError(t, err)
		chain.AppendToolsNode(toolsNode)
	})

	t.Run("test chain with compile option", func(t *testing.T) {
		chain := NewChain[map[string]any, map[string]any]()
		chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}))
		r, err := chain.Compile(ctx, WithMaxRunSteps(10))
		assert.NoError(t, err)
		assert.NotNil(t, r)
	})

	t.Run("test chain return type", func(t *testing.T) {
		// A node returning `any` is accepted where the chain output is map[string]any.
		t.Run("test chain any output type", func(t *testing.T) {
			chain := NewChain[map[string]any, map[string]any]()
			chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (any, error) {
				return 1, nil
			}))
			_, err := chain.Compile(ctx)
			assert.Nil(t, err)
		})

		// A node returning string does not match the chain's map[string]any output.
		t.Run("test chain error output type", func(t *testing.T) {
			chain := NewChain[map[string]any, map[string]any]()
			chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
				return "123", nil
			}))
			_, err := chain.Compile(ctx)
			assert.Error(t, err)
		})

		// A node taking string does not match the chain's map[string]any input.
		t.Run("test chain error input type", func(t *testing.T) {
			chain := NewChain[map[string]any, map[string]any]()
			chain.AppendLambda(InvokableLambda(func(ctx context.Context, input string) (map[string]any, error) {
				return nil, nil
			}))
			_, err := chain.Compile(ctx)
			assert.Error(t, err)
		})
	})
}

// TestParallelMultiNodes calls every Parallel.Add* method. It then checks that
// p.err is set in three cases:
//   - a duplicate output key;
//   - a nil node;
//   - a node without nodeInfo.
func TestParallelMultiNodes(t *testing.T) {
	ctx := context.Background()
	p := NewParallel()
	p.AddLambda("lambda", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
		return kvs, nil
	}))
	p.AddGraph("graph", NewChain[map[string]any, map[string]any]())
	p.AddIndexer("indexer", indexer.NewMockIndexer(gomock.NewController(t)))
	p.AddLoader("loader", document.NewMockLoader(gomock.NewController(t)))
	p.AddDocumentTransformer("document transformer", document.NewMockTransformer(gomock.NewController(t)))
	p.AddRetriever("retriever", retriever.NewMockRetriever(gomock.NewController(t)))
	p.AddChatModel("chatmodel", model.NewMockChatModel(gomock.NewController(t)))
	p.AddChatTemplate("chatTemplate", prompt.FromMessages(schema.FString, schema.SystemMessage("hello")))
	p.AddEmbedding("embedding", embedding.NewMockEmbedder(gomock.NewController(t)))
	p.AddPassthrough("passthrough")
	toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{})
	assert.NoError(t, err)
	p.AddToolsNode("tools", toolsNode)
	assert.Greater(t, len(p.nodes), 6)

	// Reusing output key "key" records an error; the later AddRetriever is then a no-op.
	ctrl := gomock.NewController(t)
	p = NewParallel()
	p.AddIndexer("key", indexer.NewMockIndexer(ctrl))
	p.AddLoader("key", document.NewMockLoader(ctrl))
	p.AddRetriever("r", retriever.NewMockRetriever(ctrl))
	assert.NotNil(t, p.err)

	// A nil node is rejected.
	p = NewParallel()
	p.addNode("k", nil, nil)
	assert.NotNil(t, p.err)

	// A graphNode without nodeInfo is rejected, even when outputKeys starts nil.
	p = &Parallel{
		outputKeys: nil,
	}
	p.addNode("k", &graphNode{}, nil)
	assert.NotNil(t, p.err)
}

// FakeLambdaOptions, FakeLambdaOption and FakeWithLambdaInfo form a minimal
// lambda option set. TestChainWithNodeKey uses them to observe which options
// reach which node.
type FakeLambdaOptions struct {
	Info string
}

type FakeLambdaOption func(opt *FakeLambdaOptions)

func FakeWithLambdaInfo(info string) FakeLambdaOption {
	return func(opt *FakeLambdaOptions) {
		opt.Info = info
	}
}

// TestChainWithNodeKey checks two things about WithNodeKey keys:
//   - options can be routed to a node via DesignateNode;
//   - the keys appear in compile and invoke error messages.
//
// Without WithNodeKey, generated keys such as node_1_branch_lambda_03 or
// node_0_parallel_0 appear instead; those keys depend on the order of the
// Append/Add calls below.
//
// NOTE(review): several subtests call err.Error() after only assert.Nil(t, c);
// if Compile ever succeeded unexpectedly this would panic on a nil err rather
// than fail cleanly.
func TestChainWithNodeKey(t *testing.T) {
	ctx := context.Background()
	t.Run("test normal chain with node key option", func(t *testing.T) {
		chain := NewChain[map[string]any, map[string]any]()
		chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			return kvs, nil
		}), WithNodeKey("lambda_01"))

		b := NewChainBranch(func(ctx context.Context, input map[string]any) (string, error) {
			return "lambda_02", nil
		})
		b.AddLambda("lambda_02", InvokableLambdaWithOption(func(ctx context.Context, kvs map[string]any, opts ...FakeLambdaOption) (map[string]any, error) {
			opt := &FakeLambdaOptions{}
			for _, optFn := range opts {
				optFn(opt)
			}
			kvs["lambda_02"] = opt.Info
			return kvs, nil
		}), WithNodeKey("lambda_02"))
		b.AddLambda("lambda_03",
			InvokableLambdaWithOption(func(ctx context.Context, kvs map[string]any, opts ...FakeLambdaOption) (map[string]any, error) {
				opt := &FakeLambdaOptions{}
				for _, optFn := range opts {
					optFn(opt)
				}
				kvs["lambda_03"] = opt.Info
				return kvs, nil
			}), WithNodeKey("lambda_03"))
		chain.AppendBranch(b)
		chain.AppendPassthrough()

		p := NewParallel()
		p.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
			return kvs["lambda_02"].(string), nil
		}))
		p.AddLambda("lambda_04", InvokableLambdaWithOption(func(ctx context.Context, kvs map[string]any, opts ...FakeLambdaOption) (string, error) {
			opt := &FakeLambdaOptions{}
			for _, optFn := range opts {
				optFn(opt)
			}
			return opt.Info, nil
		}), WithNodeKey("lambda_04"))
		p.AddLambda("lambda_05", InvokableLambdaWithOption(func(ctx context.Context, kvs map[string]any, opts ...FakeLambdaOption) (string, error) {
			opt := &FakeLambdaOptions{}
			for _, optFn := range opts {
				optFn(opt)
			}
			return opt.Info, nil
		}), WithNodeKey("lambda_05"))
		chain.AppendParallel(p)

		chain.AppendLambda(InvokableLambdaWithOption(func(ctx context.Context, kvs map[string]any, opts ...FakeLambdaOption) (map[string]any, error) {
			opt := &FakeLambdaOptions{}
			for _, optFn := range opts {
				optFn(opt)
			}
			kvs["lambda_06"] = opt.Info
			return kvs, nil
		}), WithNodeKey("lambda_06"))

		r, err := chain.Compile(ctx)
		assert.Nil(t, err)

		res, err := r.Invoke(ctx, map[string]any{},
			WithLambdaOption(FakeWithLambdaInfo("normal")),
			WithLambdaOption(FakeWithLambdaInfo("info_lambda_02")).DesignateNode("lambda_02"), // branch
			WithLambdaOption(FakeWithLambdaInfo("info_lambda_03")).DesignateNode("lambda_03"), // branch (won't run: the condition always returns "lambda_02")
			WithLambdaOption(FakeWithLambdaInfo("info_lambda_05")).DesignateNode("lambda_05"), // parallel
		)
		assert.Nil(t, err)
		assert.Equal(t, "info_lambda_02", res["lambda_02"]) // option delivered via DesignateNode
		assert.Equal(t, "info_lambda_05", res["lambda_05"]) // option delivered via DesignateNode
		assert.Equal(t, "normal", res["lambda_06"])         // no DesignateNode, so the undesignated option applies
	})

	t.Run("test chain with node key option and error with correct error info", func(t *testing.T) {
		// The lambda returns string, but the chain output is map[string]any.
		t.Run("compile error of chain", func(t *testing.T) {
			chain := NewChain[map[string]any, map[string]any]()
			chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
				return "123", nil
			}), WithNodeKey("lambda_01"))

			c, err := chain.Compile(ctx)
			assert.Nil(t, c)
			fmt.Printf("%+v\n", err)
			assert.Contains(t, err.Error(), "edge[lambda_01]")
		})

		// lambda_03 returns string where the chain output is map[string]any.
		t.Run("compile error of branch", func(t *testing.T) {
			t.Run("without node key", func(t *testing.T) {
				chain := NewChain[map[string]any, map[string]any]()
				chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return kvs, nil
				}), WithNodeKey("lambda_01"))

				b := NewChainBranch(func(ctx context.Context, input map[string]any) (string, error) {
					return "lambda_02", nil
				})
				b.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return kvs, nil
				}))
				b.AddLambda("lambda_03", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
					return "", nil
				}))
				chain.AppendBranch(b)

				c, err := chain.Compile(ctx)
				assert.Nil(t, c)
				fmt.Printf("%+v\n", err)
				assert.Contains(t, err.Error(), "edge[node_1_branch_lambda_03]") // no WithNodeKey, so the generated default node key is expected
			})

			t.Run("with node key", func(t *testing.T) {
				chain := NewChain[map[string]any, map[string]any]()
				chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return kvs, nil
				}), WithNodeKey("lambda_01"))

				b := NewChainBranch(func(ctx context.Context, input map[string]any) (string, error) {
					return "lambda_02", nil
				})
				b.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return kvs, nil
				}), WithNodeKey("lambda_02"))
				b.AddLambda("lambda_03", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
					return "123", nil
				}), WithNodeKey("key_of_lambda_03"))
				chain.AppendBranch(b)

				c, err := chain.Compile(ctx)
				assert.Nil(t, c)
				fmt.Printf("%+v\n", err)
				assert.Contains(t, err.Error(), "edge[key_of_lambda_03]")
			})
		})

		// lambda_03 takes string, but the parallel feeds it map[string]any.
		t.Run("compile error of parallel", func(t *testing.T) {
			t.Run("without node key", func(t *testing.T) {
				chain := NewChain[map[string]any, map[string]any]()
				chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return kvs, nil
				}), WithNodeKey("lambda_01"))

				p := NewParallel()
				p.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return kvs, nil
				}))
				p.AddLambda("lambda_03", InvokableLambda(func(ctx context.Context, v string) (string, error) {
					return "", nil
				}))
				chain.AppendParallel(p)

				c, err := chain.Compile(ctx)
				assert.Nil(t, c)
				fmt.Printf("%+v\n", err)
				assert.Contains(t, err.Error(), "to=node_1_parallel_1") // no WithNodeKey, so the generated default node key is expected
			})

			t.Run("with node key", func(t *testing.T) {
				chain := NewChain[map[string]any, map[string]any]()
				chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return kvs, nil
				}), WithNodeKey("lambda_01"))

				p := NewParallel()
				p.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return kvs, nil
				}), WithNodeKey("lambda_02"))
				p.AddLambda("lambda_03", InvokableLambda(func(ctx context.Context, v string) (string, error) {
					return "", nil
				}), WithNodeKey("key_of_lambda_03"))
				chain.AppendParallel(p)

				c, err := chain.Compile(ctx)
				assert.Nil(t, c)
				fmt.Printf("%+v\n", err)
				assert.Contains(t, err.Error(), "to=key_of_lambda_03")
			})
		})

		// Runtime errors from a failing node should name that node's key.
		t.Run("invoke error", func(t *testing.T) {
			t.Run("branch with out node key", func(t *testing.T) {
				chain := NewChain[map[string]any, map[string]any]()
				b := NewChainBranch(func(ctx context.Context, input map[string]any) (string, error) {
					return "lambda_01", nil
				})
				b.AddLambda("lambda_01", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return nil, fmt.Errorf("fake error")
				}))
				b.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return nil, nil
				}))
				chain.AppendBranch(b)
				c, err := chain.Compile(ctx)
				assert.Nil(t, err)
				_, err = c.Invoke(ctx, map[string]any{})
				fmt.Printf("%+v\n", err)
				assert.Error(t, err)
				assert.Contains(t, err.Error(), "node_0_branch_lambda_01") // no WithNodeKey, so the generated default node key is expected
			})

			t.Run("branch with node key", func(t *testing.T) {
				chain := NewChain[map[string]any, map[string]any]()
				b := NewChainBranch(func(ctx context.Context, input map[string]any) (string, error) {
					return "lambda_01", nil
				})
				b.AddLambda("lambda_01", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return nil, fmt.Errorf("fake error")
				}), WithNodeKey("key_of_lambda_01"))
				b.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return nil, nil
				}))
				chain.AppendBranch(b)
				c, err := chain.Compile(ctx)
				assert.Nil(t, err)
				_, err = c.Invoke(ctx, map[string]any{})
				fmt.Printf("%+v\n", err)
				assert.Error(t, err)
				assert.Contains(t, err.Error(), "key_of_lambda_01")
			})

			t.Run("parallel with out node key", func(t *testing.T) {
				chain := NewChain[map[string]any, map[string]any]()
				p := NewParallel()
				p.AddLambda("lambda_01", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return nil, fmt.Errorf("fake error")
				}))
				p.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return nil, nil
				}))
				chain.AppendParallel(p)
				c, err := chain.Compile(ctx)
				assert.Nil(t, err)
				_, err = c.Invoke(ctx, map[string]any{})
				fmt.Printf("%+v\n", err)
				assert.Error(t, err)
				assert.Contains(t, err.Error(), "node_0_parallel_0") // no WithNodeKey, so the generated
				// default node key is expected
			})

			t.Run("parallel with node key", func(t *testing.T) {
				chain := NewChain[map[string]any, map[string]any]()
				p := NewParallel()
				p.AddLambda("lambda_01", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return nil, fmt.Errorf("fake error")
				}), WithNodeKey("key_of_lambda_01"))
				p.AddLambda("lambda_02", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
					return nil, nil
				}))
				chain.AppendParallel(p)
				c, err := chain.Compile(ctx)
				assert.Nil(t, err)
				_, err = c.Invoke(ctx, map[string]any{})
				fmt.Printf("%+v\n", err)
				assert.Error(t, err)
				assert.Contains(t, err.Error(), "key_of_lambda_01")
			})
		})
	})
}

================================================
FILE: compose/checkpoint.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package compose

import (
	"context"
	"fmt"

	"github.com/cloudwego/eino/internal/core"
	"github.com/cloudwego/eino/internal/serialization"
	"github.com/cloudwego/eino/schema"
)

// init registers the internal checkpoint-related types for serialization under
// "_eino_"-prefixed names (a prefix user registrations are told to avoid).
func init() {
	schema.RegisterName[*checkpoint]("_eino_checkpoint")
	schema.RegisterName[*dagChannel]("_eino_dag_channel")
	schema.RegisterName[*pregelChannel]("_eino_pregel_channel")
	schema.RegisterName[dependencyState]("_eino_dependency_state")
	// The registration error is deliberately discarded.
	// NOTE(review): presumably it cannot fail for this fixed internal name — confirm.
	_ = serialization.GenericRegister[channel]("_eino_channel")
}

// RegisterSerializableType registers a custom type for eino serialization.
// This allows eino to properly serialize and deserialize custom types. // Both custom interfaces and structs need to be registered using this function. // Types only need to be registered once - pointers and other references will be handled automatically. // All built-in eino types are already registered. // Parameters: // - name: A unique identifier for the type being registered (should not start with "_eino") // - T: The generic type parameter representing the type to register // Returns: // - error: An error if registration fails (e.g., if the type is already registered) // Deprecated: RegisterSerializableType is deprecated. Use schema.RegisterName[T](name) instead. func RegisterSerializableType[T any](name string) (err error) { return serialization.GenericRegister[T](name) } type CheckPointStore = core.CheckPointStore type Serializer interface { Marshal(v any) ([]byte, error) Unmarshal(data []byte, v any) error } // WithCheckPointStore sets the checkpoint store implementation for a graph. func WithCheckPointStore(store CheckPointStore) GraphCompileOption { return func(o *graphCompileOptions) { o.checkPointStore = store } } // WithSerializer sets the serializer used to persist checkpoint state. func WithSerializer(serializer Serializer) GraphCompileOption { return func(o *graphCompileOptions) { o.serializer = serializer } } // WithCheckPointID sets the checkpoint ID to load from and write to by default. func WithCheckPointID(checkPointID string) Option { return Option{ checkPointID: &checkPointID, } } // WithWriteToCheckPointID specifies a different checkpoint ID to write to. // If not provided, the checkpoint ID from WithCheckPointID will be used for writing. // This is useful for scenarios where you want to load from an existed checkpoint // but save the progress to a new, separate checkpoint. 
func WithWriteToCheckPointID(checkPointID string) Option { return Option{ writeToCheckPointID: &checkPointID, } } // WithForceNewRun forces the graph to run from the beginning, ignoring any checkpoints. func WithForceNewRun() Option { return Option{ forceNewRun: true, } } // StateModifier modifies state during checkpoint operations for a given node path. type StateModifier func(ctx context.Context, path NodePath, state any) error // WithStateModifier installs a state modifier invoked during checkpoint read/write. func WithStateModifier(sm StateModifier) Option { return Option{ stateModifier: sm, } } type checkpoint struct { Channels map[string]channel Inputs map[string] /*node key*/ any /*input*/ State any SkipPreHandler map[string]bool RerunNodes []string SubGraphs map[string]*checkpoint InterruptID2Addr map[string]Address InterruptID2State map[string]core.InterruptState } type stateModifierKey struct{} type checkPointKey struct{} // *checkpoint func getStateModifier(ctx context.Context) StateModifier { if sm, ok := ctx.Value(stateModifierKey{}).(StateModifier); ok { return sm } return nil } func setStateModifier(ctx context.Context, modifier StateModifier) context.Context { return context.WithValue(ctx, stateModifierKey{}, modifier) } func getCheckPointFromStore(ctx context.Context, id string, cpr *checkPointer) (cp *checkpoint, err error) { cp, existed, err := cpr.get(ctx, id) if err != nil { return nil, err } if !existed { return nil, nil } return cp, nil } func setCheckPointToCtx(ctx context.Context, cp *checkpoint) context.Context { ctx = core.PopulateInterruptState(ctx, cp.InterruptID2Addr, cp.InterruptID2State) return context.WithValue(ctx, checkPointKey{}, cp) } func getCheckPointFromCtx(ctx context.Context) *checkpoint { if cp, ok := ctx.Value(checkPointKey{}).(*checkpoint); ok { return cp } return nil } func forwardCheckPoint(ctx context.Context, nodeKey string) context.Context { cp := getCheckPointFromCtx(ctx) if cp == nil { return ctx } if subCP, ok := 
cp.SubGraphs[nodeKey]; ok { delete(cp.SubGraphs, nodeKey) // only forward once return context.WithValue(ctx, checkPointKey{}, subCP) } return context.WithValue(ctx, checkPointKey{}, (*checkpoint)(nil)) } func newCheckPointer( inputPairs, outputPairs map[string]streamConvertPair, store CheckPointStore, serializer Serializer, ) *checkPointer { if serializer == nil { serializer = &serialization.InternalSerializer{} } return &checkPointer{ sc: newStreamConverter(inputPairs, outputPairs), store: store, serializer: serializer, } } type checkPointer struct { sc *streamConverter store CheckPointStore serializer Serializer } func (c *checkPointer) get(ctx context.Context, id string) (*checkpoint, bool, error) { data, existed, err := c.store.Get(ctx, id) if err != nil || existed == false { return nil, existed, err } cp := &checkpoint{} err = c.serializer.Unmarshal(data, cp) if err != nil { return nil, false, err } return cp, true, nil } func (c *checkPointer) set(ctx context.Context, id string, cp *checkpoint) error { data, err := c.serializer.Marshal(cp) if err != nil { return err } return c.store.Set(ctx, id, data) } // MigrateCheckpointState is an advanced compatibility utility for checkpoint upgrades. // // It decodes checkpoint bytes using the given serializer, applies migrate to checkpoint.State and // all nested SubGraphs' states, then re-encodes the checkpoint. // // Typical use cases: // - Resume-time migration when you changed your graph state type/schema and need to load old // checkpoints without discarding them. // - Framework-level backward compatibility (e.g. ADK upgrading checkpoints across versions). // // Migrate callback contract: // - Returns (newState, changed, error). // - If changed is false, the state is left as-is. // - If error is non-nil, migration stops and the error is returned to the caller. // // The original bytes are returned only if no state was changed anywhere in the checkpoint tree. 
func MigrateCheckpointState(data []byte, serializer Serializer, migrate func(state any) (any, bool, error)) ([]byte, error) { cp := &checkpoint{} if err := serializer.Unmarshal(data, cp); err != nil { return nil, err } changed, err := migrateCheckpoint(cp, migrate) if err != nil { return nil, err } if !changed { return data, nil } return serializer.Marshal(cp) } // migrateCheckpoint recursively applies migrate to cp.State and all SubGraphs. func migrateCheckpoint(cp *checkpoint, migrate func(state any) (any, bool, error)) (bool, error) { anyChanged := false if cp.State != nil { newState, changed, err := migrate(cp.State) if err != nil { return false, err } if changed { cp.State = newState anyChanged = true } } for _, sub := range cp.SubGraphs { changed, err := migrateCheckpoint(sub, migrate) if err != nil { return false, err } if changed { anyChanged = true } } return anyChanged, nil } // convertCheckPoint if value in checkpoint is streamReader, convert it to non-stream func (c *checkPointer) convertCheckPoint(cp *checkpoint, isStream bool) (err error) { for _, ch := range cp.Channels { err = ch.convertValues(func(m map[string]any) error { return c.sc.convertOutputs(isStream, m) }) if err != nil { return err } } err = c.sc.convertInputs(isStream, cp.Inputs) if err != nil { return err } return nil } // convertCheckPoint convert values in checkpoint to streamReader if needed func (c *checkPointer) restoreCheckPoint(cp *checkpoint, isStream bool) (err error) { for _, ch := range cp.Channels { err = ch.convertValues(func(m map[string]any) error { return c.sc.restoreOutputs(isStream, m) }) if err != nil { return err } } err = c.sc.restoreInputs(isStream, cp.Inputs) if err != nil { return err } return nil } func newStreamConverter(inputPairs, outputPairs map[string]streamConvertPair) *streamConverter { return &streamConverter{ inputPairs: inputPairs, outputPairs: outputPairs, } } type streamConverter struct { inputPairs, outputPairs map[string]streamConvertPair } func (s 
*streamConverter) convertInputs(isStream bool, values map[string]any) error { return convert(values, s.inputPairs, isStream) } func (s *streamConverter) restoreInputs(isStream bool, values map[string]any) error { return restore(values, s.inputPairs, isStream) } func (s *streamConverter) convertOutputs(isStream bool, values map[string]any) error { return convert(values, s.outputPairs, isStream) } func (s *streamConverter) restoreOutputs(isStream bool, values map[string]any) error { return restore(values, s.outputPairs, isStream) } func convert(values map[string]any, convPairs map[string]streamConvertPair, isStream bool) error { if !isStream { return nil } for key, v := range values { convPair, ok := convPairs[key] if !ok { return fmt.Errorf("checkpoint conv stream fail, node[%s] have not been registered", key) } sr, ok := v.(streamReader) if !ok { return fmt.Errorf("checkpoint conv stream fail, value of [%s] isn't stream", key) } nValue, err := convPair.concatStream(sr) if err != nil { return err } values[key] = nValue } return nil } func restore(values map[string]any, convPairs map[string]streamConvertPair, isStream bool) error { if !isStream { return nil } for key, v := range values { convPair, ok := convPairs[key] if !ok { return fmt.Errorf("checkpoint restore stream fail, node[%s] have not been registered", key) } sr, err := convPair.restoreStream(v) if err != nil { return err } values[key] = sr } return nil } ================================================ FILE: compose/checkpoint_migrate_test.go ================================================ /* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "errors" "testing" "github.com/stretchr/testify/assert" ) type stubSerializer struct { unmarshal func(data []byte, v any) error marshal func(v any) ([]byte, error) } func (s stubSerializer) Marshal(v any) ([]byte, error) { return s.marshal(v) } func (s stubSerializer) Unmarshal(data []byte, v any) error { return s.unmarshal(data, v) } func TestMigrateCheckpointState_UnmarshalError(t *testing.T) { in := []byte("in") codec := stubSerializer{ unmarshal: func(_ []byte, _ any) error { return errors.New("bad") }, marshal: func(_ any) ([]byte, error) { return []byte("unused"), nil }, } _, err := MigrateCheckpointState(in, codec, func(state any) (any, bool, error) { return state, false, nil }) assert.Error(t, err) } func TestMigrateCheckpointState_NoChangeReturnsOriginalBytes(t *testing.T) { in := []byte("in") cp := &checkpoint{State: "s"} codec := stubSerializer{ unmarshal: func(_ []byte, v any) error { *(v.(*checkpoint)) = *cp return nil }, marshal: func(_ any) ([]byte, error) { return []byte("marshaled"), nil }, } out, err := MigrateCheckpointState(in, codec, func(state any) (any, bool, error) { return state, false, nil }) assert.NoError(t, err) assert.Equal(t, in, out) } func TestMigrateCheckpointState_ChangeTriggersMarshal(t *testing.T) { in := []byte("in") cp := &checkpoint{State: "s"} var sawState any codec := stubSerializer{ unmarshal: func(_ []byte, v any) error { *(v.(*checkpoint)) = *cp return nil }, marshal: func(v any) ([]byte, error) { sawState = v.(*checkpoint).State return []byte("marshaled"), nil }, } out, err := 
MigrateCheckpointState(in, codec, func(state any) (any, bool, error) { return "s2", true, nil }) assert.NoError(t, err) assert.Equal(t, []byte("marshaled"), out) assert.Equal(t, "s2", sawState) } func TestMigrateCheckpointState_MigrateErrorStops(t *testing.T) { in := []byte("in") cp := &checkpoint{ State: "root", SubGraphs: map[string]*checkpoint{ "sub": {State: "sub"}, }, } codec := stubSerializer{ unmarshal: func(_ []byte, v any) error { *(v.(*checkpoint)) = *cp return nil }, marshal: func(_ any) ([]byte, error) { return []byte("marshaled"), nil }, } _, err := MigrateCheckpointState(in, codec, func(state any) (any, bool, error) { if state == "sub" { return nil, false, errors.New("boom") } return state, false, nil }) assert.Error(t, err) } ================================================ FILE: compose/checkpoint_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

package compose

import (
	"context"
	"errors"
	"io"
	"sync"
	"testing"
	"time"

	"github.com/bytedance/sonic"
	"github.com/stretchr/testify/assert"

	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/internal/callbacks"
	"github.com/cloudwego/eino/internal/generic"
	"github.com/cloudwego/eino/internal/serialization"
	"github.com/cloudwego/eino/schema"
)

// inMemoryStore is a map-backed checkpoint store for tests. It does no locking,
// so it must not be used by concurrently running graphs.
type inMemoryStore struct {
	m map[string][]byte
}

// Get returns the bytes stored under checkPointID and whether that ID exists.
// It never returns an error.
func (i *inMemoryStore) Get(_ context.Context, checkPointID string) ([]byte, bool, error) {
	v, ok := i.m[checkPointID]
	return v, ok, nil
}

// Set stores checkPoint under checkPointID, replacing any previous value.
func (i *inMemoryStore) Set(_ context.Context, checkPointID string, checkPoint []byte) error {
	i.m[checkPointID] = checkPoint
	return nil
}

// newInMemoryStore returns an empty inMemoryStore.
func newInMemoryStore() *inMemoryStore {
	return &inMemoryStore{
		m: make(map[string][]byte),
	}
}

// testStruct is used both as graph local state and as a custom value passed
// between nodes in the tests below.
type testStruct struct {
	A string
}

// init registers testStruct with schema.Register; presumably this lets the
// checkpoint serializer encode and decode it — confirm.
func init() {
	schema.Register[testStruct]()
}

// TestSimpleCheckPoint interrupts a two-node graph after node "1" / before node
// "2", then resumes it. Node "2"'s state pre-handler appends state.A, so the
// resumed run yields "start" + "1" + "state" + "2".
func TestSimpleCheckPoint(t *testing.T) {
	store := newInMemoryStore()
	g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testStruct) {
		return &testStruct{A: ""}
	}))
	err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "1", nil
	}))
	assert.NoError(t, err)
	err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "2", nil
	}), WithStatePreHandler(func(ctx context.Context, in string, state *testStruct) (string, error) {
		return in + state.A, nil
	}))
	assert.NoError(t, err)
	err = g.AddEdge(START, "1")
	assert.NoError(t, err)
	err = g.AddEdge("1", "2")
	assert.NoError(t, err)
	err = g.AddEdge("2", END)
	assert.NoError(t, err)
	ctx := context.Background()
	r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(store), WithInterruptAfterNodes([]string{"1"}), WithInterruptBeforeNodes([]string{"2"}), WithGraphName("root"))
	assert.NoError(t, err)

	// First run stops with an interrupt error carrying the initial (empty) state.
	_, err = r.Invoke(ctx, "start", WithCheckPointID("1"))
	assert.NotNil(t, err)
	info, ok := ExtractInterruptInfo(err)
	assert.True(t, ok)
	assert.Equal(t, &testStruct{A: ""}, info.State)
	assert.Equal(t, []string{"2"}, info.BeforeNodes)
	assert.Equal(t, []string{"1"}, info.AfterNodes)
	assert.Empty(t, info.RerunNodesExtra)
	assert.Empty(t, info.SubGraphs)
	assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{
		Address: Address{
			{
				Type: AddressSegmentRunnable,
				ID:   "root",
			},
		},
		Info: &testStruct{
			A: "",
		},
		IsRootCause: true,
	}))

	// Resume from checkpoint "1" with &testStruct{A: "state"} as resume data;
	// the expected output shows it is used as the graph state by node "2".
	rCtx := ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"})
	result, err := r.Invoke(rCtx, "start", WithCheckPointID("1"))
	assert.NoError(t, err)
	assert.Equal(t, "start1state2", result)

	// NOTE(review): the streaming variant of this scenario is commented out below;
	// no reason is recorded here — confirm whether it can be re-enabled.
	/*
		_, err = r.Stream(ctx, "start", WithCheckPointID("2"))
		assert.NotNil(t, err)
		info, ok = ExtractInterruptInfo(err)
		assert.True(t, ok)
		assert.Equal(t, &testStruct{A: ""}, info.State)
		assert.Equal(t, []string{"2"}, info.BeforeNodes)
		assert.Equal(t, []string{"1"}, info.AfterNodes)
		assert.Empty(t, info.RerunNodesExtra)
		assert.Empty(t, info.SubGraphs)
		assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{
			Address: Address{
				{
					Type: AddressSegmentRunnable,
					ID:   "root",
				},
			},
			Info: &testStruct{
				A: "",
			},
			IsRootCause: true,
		}))
		rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"})
		streamResult, err := r.Stream(rCtx, "start", WithCheckPointID("2"))
		assert.NoError(t, err)
		result = ""
		for {
			chunk, err := streamResult.Recv()
			if err == io.EOF {
				break
			}
			assert.NoError(t, err)
			result += chunk
		}
		assert.Equal(t, "start1state2", result)*/
}

// TestCustomStructInAn2y checks that a custom struct (*testStruct) emitted by
// node "1" under output key "1" survives an interrupt after node "1" and is
// usable on resume, in both Invoke and Stream modes (checkpoint IDs "1" and "2").
// NOTE(review): the name looks like a typo for TestCustomStructInAny.
func TestCustomStructInAn2y(t *testing.T) {
	store := newInMemoryStore()
	g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testStruct) {
		return &testStruct{A: ""}
	}))
	err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output *testStruct, err error) {
		return &testStruct{A: input + "1"}, nil
	}), WithOutputKey("1"))
	assert.NoError(t, err)
	err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input map[string]any) (output string, err error) {
		return input["1"].(*testStruct).A + "2", nil
	}), WithStatePreHandler(func(ctx context.Context, in map[string]any, state *testStruct) (map[string]any, error) {
		// Mutates the checkpointed *testStruct in place with the resumed state.
		in["1"].(*testStruct).A += state.A
		return in, nil
	}))
	assert.NoError(t, err)
	err = g.AddEdge(START, "1")
	assert.NoError(t, err)
	err = g.AddEdge("1", "2")
	assert.NoError(t, err)
	err = g.AddEdge("2", END)
	assert.NoError(t, err)
	ctx := context.Background()
	r, err := g.Compile(ctx, WithCheckPointStore(store), WithInterruptAfterNodes([]string{"1"}), WithGraphName("root"))
	assert.NoError(t, err)

	// Invoke mode.
	_, err = r.Invoke(ctx, "start", WithCheckPointID("1"))
	assert.NotNil(t, err)
	info, ok := ExtractInterruptInfo(err)
	assert.True(t, ok)
	assert.Equal(t, &testStruct{A: ""}, info.State)
	assert.Equal(t, []string{"1"}, info.AfterNodes)
	assert.Empty(t, info.RerunNodesExtra)
	assert.Empty(t, info.SubGraphs)
	assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{
		Address: Address{
			{
				Type: AddressSegmentRunnable,
				ID:   "root",
			},
		},
		Info: &testStruct{
			A: "",
		},
		IsRootCause: true,
	}))
	rCtx := ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"})
	result, err := r.Invoke(rCtx, "start", WithCheckPointID("1"))
	assert.NoError(t, err)
	assert.Equal(t, "start1state2", result)

	// Stream mode, using a separate checkpoint ID.
	_, err = r.Stream(ctx, "start", WithCheckPointID("2"))
	assert.NotNil(t, err)
	info, ok = ExtractInterruptInfo(err)
	assert.True(t, ok)
	assert.Equal(t, &testStruct{A: ""}, info.State)
	assert.Equal(t, []string{"1"}, info.AfterNodes)
	assert.Empty(t, info.RerunNodesExtra)
	assert.Empty(t, info.SubGraphs)
	assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{
		Address: Address{
			{
				Type: AddressSegmentRunnable,
				ID:   "root",
			},
		},
		Info: &testStruct{
			A: "",
		},
		IsRootCause: true,
	}))
	rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"})
	streamResult, err := r.Stream(rCtx, "start", WithCheckPointID("2"))
	assert.NoError(t, err)
	result = ""
	for {
		chunk, err := streamResult.Recv()
		if err == io.EOF {
			break
		}
		assert.NoError(t, err)
		result += chunk
	}
	assert.Equal(t, "start1state2", result)
}

// TestSubGraph runs a graph whose node "2" is a subgraph compiled with an
// interrupt after its own node "1". The interrupt surfaces under
// info.SubGraphs["2"] with address root -> node:2. After resuming with state
// "state", the output is "start" + "1" (g.1) + "1" (subG.1) + "state" + "2"
// (subG.2) + "3" (g.3).
func TestSubGraph(t *testing.T) {
	subG := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testStruct) {
		return &testStruct{A: ""}
	}))
	err := subG.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "1", nil
	}))
	assert.NoError(t, err)
	err = subG.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "2", nil
	}), WithStatePreHandler(func(ctx context.Context, in string, state *testStruct) (string, error) {
		return in + state.A, nil
	}))
	assert.NoError(t, err)
	err = subG.AddEdge(START, "1")
	assert.NoError(t, err)
	err = subG.AddEdge("1", "2")
	assert.NoError(t, err)
	err = subG.AddEdge("2", END)
	assert.NoError(t, err)

	g := NewGraph[string, string]()
	err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "1", nil
	}))
	assert.NoError(t, err)
	err = g.AddGraphNode("2", subG, WithGraphCompileOptions(WithInterruptAfterNodes([]string{"1"})))
	assert.NoError(t, err)
	err = g.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "3", nil
	}))
	assert.NoError(t, err)
	err = g.AddEdge(START, "1")
	assert.NoError(t, err)
	err = g.AddEdge("1", "2")
	assert.NoError(t, err)
	err = g.AddEdge("2", "3")
	assert.NoError(t, err)
	err = g.AddEdge("3", END)
	assert.NoError(t, err)
	ctx := context.Background()
	r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore()), WithGraphName("root"))
	assert.NoError(t, err)

	// Invoke mode.
	_, err = r.Invoke(ctx, "start", WithCheckPointID("1"))
	assert.NotNil(t, err)
	info, ok := ExtractInterruptInfo(err)
	assert.True(t, ok)
	assert.Equal(t, map[string]*InterruptInfo{
		"2": {
			State:           &testStruct{A: ""},
			AfterNodes:      []string{"1"},
			RerunNodesExtra: make(map[string]interface{}),
			SubGraphs:       make(map[string]*InterruptInfo),
		},
	}, info.SubGraphs)
	assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{
		Address: Address{
			{
				Type: AddressSegmentRunnable,
				ID:   "root",
			},
			{
				Type: AddressSegmentNode,
				ID:   "2",
			},
		},
		Info: &testStruct{
			A: "",
		},
		IsRootCause: true,
		Parent: &InterruptCtx{
			Address: Address{
				{
					Type: AddressSegmentRunnable,
					ID:   "root",
				},
			},
		},
	}))
	rCtx := ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"})
	result, err := r.Invoke(rCtx, "start", WithCheckPointID("1"))
	assert.NoError(t, err)
	assert.Equal(t, "start11state23", result)

	// Stream mode, using a separate checkpoint ID.
	_, err = r.Stream(ctx, "start", WithCheckPointID("2"))
	assert.NotNil(t, err)
	info, ok = ExtractInterruptInfo(err)
	assert.True(t, ok)
	assert.Equal(t, map[string]*InterruptInfo{
		"2": {
			State:           &testStruct{A: ""},
			AfterNodes:      []string{"1"},
			RerunNodesExtra: make(map[string]any),
			SubGraphs:       map[string]*InterruptInfo{},
		},
	}, info.SubGraphs)
	assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{
		Address: Address{
			{
				Type: AddressSegmentRunnable,
				ID:   "root",
			},
			{
				Type: AddressSegmentNode,
				ID:   "2",
			},
		},
		Info: &testStruct{
			A: "",
		},
		IsRootCause: true,
		Parent: &InterruptCtx{
			Address: Address{
				{
					Type: AddressSegmentRunnable,
					ID:   "root",
				},
			},
		},
	}))
	rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"})
	streamResult, err := r.Stream(rCtx, "start", WithCheckPointID("2"))
	assert.NoError(t, err)
	result = ""
	for {
		chunk, err := streamResult.Recv()
		if err == io.EOF {
			break
		}
		assert.NoError(t, err)
		result += chunk
	}
	assert.Equal(t, "start11state23", result)
}

// testGraphCallback counts callback invocations whose RunInfo.Component is
// ComponentOfGraph; callbacks for other components are ignored.
type testGraphCallback struct {
	onStartTimes       int
	onEndTimes         int
	onStreamStartTimes int
	onStreamEndTimes   int
	onErrorTimes       int
}

// OnStart counts graph-level starts.
func (t *testGraphCallback) OnStart(ctx context.Context, info *callbacks.RunInfo, _ callbacks.CallbackInput) context.Context {
	if info.Component == ComponentOfGraph {
		t.onStartTimes++
	}
	return ctx
}

// OnEnd counts graph-level successful ends.
func (t *testGraphCallback) OnEnd(ctx context.Context, info *callbacks.RunInfo, _ callbacks.CallbackOutput) context.Context {
	if info.Component == ComponentOfGraph {
		t.onEndTimes++
	}
	return ctx
}

// OnError counts graph-level errors (interrupts included, as the test expectations show).
func (t *testGraphCallback) OnError(ctx context.Context, info *callbacks.RunInfo, _ error) context.Context {
	if info.Component == ComponentOfGraph {
		t.onErrorTimes++
	}
	return ctx
}

// OnStartWithStreamInput counts graph-level streaming starts. The input stream
// is closed immediately because this handler does not consume it.
func (t *testGraphCallback) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
	input.Close()
	if info.Component == ComponentOfGraph {
		t.onStreamStartTimes++
	}
	return ctx
}

// OnEndWithStreamOutput counts graph-level streaming ends. The output stream
// is closed immediately because this handler does not consume it.
func (t *testGraphCallback) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
	output.Close()
	if info.Component == ComponentOfGraph {
		t.onStreamEndTimes++
	}
	return ctx
}

func TestNestedSubGraph(t *testing.T) {
	sSubG := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testStruct) {
		return &testStruct{A: ""}
	}))
	err := sSubG.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "1", nil
	}))
	assert.NoError(t, err)
	err = sSubG.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "2", nil
	}), WithStatePreHandler(func(ctx context.Context, in string, state *testStruct) (string, error) {
		return in + state.A, nil
	}))
	assert.NoError(t, err)
	err = sSubG.AddEdge(START, "1")
	assert.NoError(t, err)
	err = sSubG.AddEdge("1", "2")
	assert.NoError(t, err)
	err = sSubG.AddEdge("2", END)
	assert.NoError(t, err)

	subG := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testStruct) {
		return &testStruct{A: ""}
	}))
	err = subG.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input + "1", nil
	}))
	assert.NoError(t, err)
	err = subG.AddGraphNode("2", sSubG,
WithGraphCompileOptions(WithInterruptAfterNodes([]string{"1"})), WithStatePreHandler(func(ctx context.Context, in string, state *testStruct) (string, error) { return in + state.A, nil }), WithOutputKey("2")) assert.NoError(t, err) err = subG.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + "3", nil }), WithOutputKey("3")) assert.NoError(t, err) err = subG.AddLambdaNode("4", InvokableLambda(func(ctx context.Context, input map[string]any) (output string, err error) { return input["2"].(string) + "4\n" + input["3"].(string) + "4\n" + input["state"].(string) + "4\n", nil }), WithStatePreHandler(func(ctx context.Context, in map[string]any, state *testStruct) (map[string]any, error) { in["state"] = state.A return in, nil })) assert.NoError(t, err) err = subG.AddEdge(START, "1") assert.NoError(t, err) err = subG.AddEdge("1", "2") assert.NoError(t, err) err = subG.AddEdge("1", "3") assert.NoError(t, err) err = subG.AddEdge("3", "4") assert.NoError(t, err) err = subG.AddEdge("2", "4") assert.NoError(t, err) err = subG.AddEdge("4", END) assert.NoError(t, err) g := NewGraph[string, string]() err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + "1", nil })) assert.NoError(t, err) err = g.AddGraphNode("2", subG, WithGraphCompileOptions(WithInterruptAfterNodes([]string{"1", "3"}), WithInterruptBeforeNodes([]string{"4"}))) assert.NoError(t, err) err = g.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + "3", nil })) assert.NoError(t, err) err = g.AddEdge(START, "1") assert.NoError(t, err) err = g.AddEdge("1", "2") assert.NoError(t, err) err = g.AddEdge("2", "3") assert.NoError(t, err) err = g.AddEdge("3", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore()), WithGraphName("root")) assert.NoError(t, err) 
tGCB := &testGraphCallback{} _, err = r.Invoke(ctx, "start", WithCheckPointID("1"), WithCallbacks(tGCB)) assert.NotNil(t, err) info, ok := ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: ""}, AfterNodes: []string{"1"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, })) rCtx := ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"}) _, err = r.Invoke(rCtx, "start", WithCheckPointID("1"), WithCallbacks(tGCB)) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: "state"}, AfterNodes: []string{"3"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: map[string]*InterruptInfo{ "2": { State: &testStruct{A: ""}, AfterNodes: []string{"1"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "state", }, Parent: &InterruptCtx{ ID: "runnable:root", Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"}) _, err = r.Invoke(rCtx, 
"start", WithCheckPointID("1"), WithCallbacks(tGCB)) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: "state"}, BeforeNodes: []string{"4"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "state", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state2"}) result, err := r.Invoke(rCtx, "start", WithCheckPointID("1"), WithCallbacks(tGCB)) assert.NoError(t, err) assert.Equal(t, `start11state1state24 start1134 state24 3`, result) _, err = r.Stream(ctx, "start", WithCheckPointID("2"), WithCallbacks(tGCB)) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: ""}, AfterNodes: []string{"1"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"}) _, err = r.Stream(rCtx, "start", WithCheckPointID("2"), WithCallbacks(tGCB)) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: "state"}, AfterNodes: []string{"3"}, 
RerunNodesExtra: make(map[string]interface{}), SubGraphs: map[string]*InterruptInfo{ "2": { State: &testStruct{A: ""}, AfterNodes: []string{"1"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "state", }, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"}) _, err = r.Stream(rCtx, "start", WithCheckPointID("2"), WithCallbacks(tGCB)) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: "state"}, BeforeNodes: []string{"4"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "state", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state2"}) streamResult, err := r.Stream(rCtx, "start", WithCheckPointID("2"), WithCallbacks(tGCB)) assert.NoError(t, err) result = "" for { chunk, err := streamResult.Recv() if err == io.EOF { break } assert.NoError(t, err) result += chunk } assert.Equal(t, `start11state1state24 start1134 state24 3`, result) assert.Equal(t, 10, 
tGCB.onStartTimes) // 3+sSubG*1*3+subG*2*2+g*0 assert.Equal(t, 3, tGCB.onEndTimes) // success*3 assert.Equal(t, 10, tGCB.onStreamStartTimes) // 3+sSubG*1*3+subG*2*2+g*0 assert.Equal(t, 3, tGCB.onStreamEndTimes) // success*3 assert.Equal(t, 14, tGCB.onErrorTimes) // 2*(sSubG*1*3+subG*2*2+g*0) // dag r, err = g.Compile(ctx, WithCheckPointStore(newInMemoryStore()), WithNodeTriggerMode(AllPredecessor), WithGraphName("root")) assert.NoError(t, err) _, err = r.Invoke(ctx, "start", WithCheckPointID("1")) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: ""}, AfterNodes: []string{"1"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"}) _, err = r.Invoke(rCtx, "start", WithCheckPointID("1")) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: "state"}, AfterNodes: []string{"3"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: map[string]*InterruptInfo{ "2": { State: &testStruct{A: ""}, AfterNodes: []string{"1"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ ID: "runnable:root;node:2;node:2", Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "", }, 
IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "state", }, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"}) _, err = r.Invoke(rCtx, "start", WithCheckPointID("1")) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: "state"}, BeforeNodes: []string{"4"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "state", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state2"}) result, err = r.Invoke(rCtx, "start", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, `start11state1state24 start1134 state24 3`, result) _, err = r.Stream(ctx, "start", WithCheckPointID("2")) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: ""}, AfterNodes: []string{"1"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, })) rCtx = ResumeWithData(ctx, 
info.InterruptContexts[0].ID, &testStruct{A: "state"}) _, err = r.Stream(rCtx, "start", WithCheckPointID("2")) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: "state"}, AfterNodes: []string{"3"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: map[string]*InterruptInfo{ "2": { State: &testStruct{A: ""}, AfterNodes: []string{"1"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "state", }, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state"}) _, err = r.Stream(rCtx, "start", WithCheckPointID("2")) assert.NotNil(t, err) info, ok = ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, map[string]*InterruptInfo{ "2": { State: &testStruct{A: "state"}, BeforeNodes: []string{"4"}, RerunNodesExtra: make(map[string]interface{}), SubGraphs: make(map[string]*InterruptInfo), }, }, info.SubGraphs) assert.True(t, info.InterruptContexts[0].EqualsWithoutID(&InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, { Type: AddressSegmentNode, ID: "2", }, }, Info: &testStruct{ A: "state", }, IsRootCause: true, Parent: &InterruptCtx{ Address: Address{ { Type: AddressSegmentRunnable, ID: "root", }, }, }, })) rCtx = ResumeWithData(ctx, info.InterruptContexts[0].ID, &testStruct{A: "state2"}) streamResult, err = r.Stream(rCtx, "start", 
WithCheckPointID("2")) assert.NoError(t, err) result = "" for { chunk, err := streamResult.Recv() if err == io.EOF { break } assert.NoError(t, err) result += chunk } assert.Equal(t, `start11state1state24 start1134 state24 3`, result) } func TestDAGInterrupt(t *testing.T) { g := NewGraph[string, map[string]any]() err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { time.Sleep(time.Millisecond * 100) return input, nil }), WithOutputKey("1")) assert.NoError(t, err) err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { time.Sleep(time.Millisecond * 200) return input, nil }), WithOutputKey("2")) assert.NoError(t, err) err = g.AddPassthroughNode("3") assert.NoError(t, err) err = g.AddEdge(START, "1") assert.NoError(t, err) err = g.AddEdge(START, "2") assert.NoError(t, err) err = g.AddEdge("1", "3") assert.NoError(t, err) err = g.AddEdge("2", "3") assert.NoError(t, err) err = g.AddEdge("3", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore()), WithInterruptAfterNodes([]string{"1", "2"})) assert.NoError(t, err) _, err = r.Invoke(ctx, "input", WithCheckPointID("1")) info, existed := ExtractInterruptInfo(err) assert.True(t, existed) assert.Equal(t, []string{"1", "2"}, info.AfterNodes) result, err := r.Invoke(ctx, "", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, map[string]any{"1": "input", "2": "input"}, result) } func TestRerunNodeInterrupt(t *testing.T) { g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testStruct) { return &testStruct{} })) times := 0 err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { defer func() { times++ }() if times%2 == 0 { return "", NewInterruptAndRerunErr("test extra") } return input, nil }), WithStatePreHandler(func(ctx context.Context, in string, state 
*testStruct) (string, error) { return state.A, nil })) assert.NoError(t, err) err = g.AddEdge(START, "1") assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) _, err = r.Invoke(ctx, "input", WithCheckPointID("1")) info, existed := ExtractInterruptInfo(err) assert.True(t, existed) assert.Equal(t, []string{"1"}, info.RerunNodes) result, err := r.Invoke(ctx, "", WithCheckPointID("1"), WithStateModifier(func(ctx context.Context, path NodePath, state any) error { state.(*testStruct).A = "state" return nil })) assert.NoError(t, err) assert.Equal(t, "state", result) _, err = r.Stream(ctx, "input", WithCheckPointID("2")) info, existed = ExtractInterruptInfo(err) assert.True(t, existed) assert.Equal(t, []string{"1"}, info.RerunNodes) assert.Equal(t, "test extra", info.RerunNodesExtra["1"].(string)) streamResult, err := r.Stream(ctx, "", WithCheckPointID("2"), WithStateModifier(func(ctx context.Context, path NodePath, state any) error { state.(*testStruct).A = "state" return nil })) assert.NoError(t, err) chunk, err := streamResult.Recv() assert.NoError(t, err) assert.Equal(t, "state", chunk) _, err = streamResult.Recv() assert.Equal(t, io.EOF, err) } type myInterface interface { A() } func TestInterfaceResume(t *testing.T) { g := NewGraph[myInterface, string]() times := 0 assert.NoError(t, g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input myInterface) (output string, err error) { if times == 0 { times++ return "", NewInterruptAndRerunErr("test extra") } return "success", nil }))) assert.NoError(t, g.AddEdge(START, "1")) assert.NoError(t, g.AddEdge("1", END)) ctx := context.Background() r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) _, err = r.Invoke(ctx, nil, WithCheckPointID("1")) info, existed := ExtractInterruptInfo(err) assert.True(t, existed) assert.Equal(t, []string{"1"}, 
info.RerunNodes) result, err := r.Invoke(ctx, nil, WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "success", result) } func TestEarlyFailCallback(t *testing.T) { g := NewGraph[string, string]() assert.NoError(t, g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }))) assert.NoError(t, g.AddEdge(START, "1")) assert.NoError(t, g.AddEdge("1", END)) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor)) assert.NoError(t, err) tGCB := &testGraphCallback{} _, _ = r.Invoke(ctx, "", WithCallbacks(tGCB), WithRuntimeMaxSteps(1)) assert.Equal(t, 1, tGCB.onStartTimes) assert.Equal(t, 1, tGCB.onErrorTimes) assert.Equal(t, 0, tGCB.onEndTimes) } func TestGraphStartInterrupt(t *testing.T) { subG := NewGraph[string, string]() _ = subG.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + "sub1", nil })) _ = subG.AddEdge(START, "1") _ = subG.AddEdge("1", END) g := NewGraph[string, string]() _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + "1", nil })) _ = g.AddGraphNode("2", subG, WithGraphCompileOptions(WithInterruptBeforeNodes([]string{"1"}))) _ = g.AddEdge(START, "1") _ = g.AddEdge("1", "2") _ = g.AddEdge("2", END) ctx := context.Background() r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) _, err = r.Invoke(ctx, "input", WithCheckPointID("1")) info, existed := ExtractInterruptInfo(err) assert.True(t, existed) assert.Equal(t, []string{"1"}, info.SubGraphs["2"].BeforeNodes) result, err := r.Invoke(ctx, "", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input1sub1", result) } func TestWithForceNewRun(t *testing.T) { g := NewGraph[string, string]() _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input 
+ "1", nil })) _ = g.AddEdge(START, "1") _ = g.AddEdge("1", END) ctx := context.Background() r, err := g.Compile(ctx, WithCheckPointStore(&failStore{t: t})) assert.NoError(t, err) result, err := r.Invoke(ctx, "input", WithCheckPointID("1"), WithForceNewRun()) assert.NoError(t, err) assert.Equal(t, "input1", result) } type failStore struct { t *testing.T } func (f *failStore) Get(_ context.Context, _ string) ([]byte, bool, error) { f.t.Fatalf("cannot call store") return nil, false, errors.New("fail") } func (f *failStore) Set(_ context.Context, _ string, _ []byte) error { f.t.Fatalf("cannot call store") return errors.New("fail") } func TestPreHandlerInterrupt(t *testing.T) { type state struct{} assert.NoError(t, serialization.GenericRegister[state]("_eino_TestPreHandlerInterrupt_state")) g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) state { return state{} })) times := 0 _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + "1", nil }), WithStatePreHandler(func(ctx context.Context, in string, state state) (string, error) { if times == 0 { times++ return "", NewInterruptAndRerunErr("") } return in, nil })) _ = g.AddEdge(START, "1") _ = g.AddEdge("1", END) ctx := context.Background() r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) _, err = r.Invoke(ctx, "input", WithCheckPointID("1")) info, existed := ExtractInterruptInfo(err) assert.True(t, existed) assert.Equal(t, []string{"1"}, info.RerunNodes) result, err := r.Invoke(ctx, "", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "1", result) } func TestCancelInterrupt(t *testing.T) { g := NewGraph[string, string]() _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { time.Sleep(3 * time.Second) return input + "1", nil })) _ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output 
string, err error) { return input + "2", nil })) _ = g.AddEdge(START, "1") _ = g.AddEdge("1", "2") _ = g.AddEdge("2", END) ctx := context.Background() // pregel r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // interrupt after nodes canceledCtx, cancel := WithGraphInterrupt(ctx) go func() { time.Sleep(500 * time.Millisecond) cancel(WithGraphInterruptTimeout(time.Hour)) }() _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1")) assert.Error(t, err) info, success := ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) result, err := r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) // infinite timeout canceledCtx, cancel = WithGraphInterrupt(ctx) go func() { time.Sleep(500 * time.Millisecond) cancel() }() _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2")) assert.Error(t, err) info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) // interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved canceledCtx, cancel = WithGraphInterrupt(ctx) go func() { time.Sleep(500 * time.Millisecond) cancel(WithGraphInterruptTimeout(0)) }() _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3")) assert.Error(t, err) info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) // dag g = NewGraph[string, string]() _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { time.Sleep(3 * time.Second) return input + "1", nil })) _ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) 
(output string, err error) { return input + "2", nil })) _ = g.AddEdge(START, "1") _ = g.AddEdge("1", "2") _ = g.AddEdge("2", END) r, err = g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // interrupt after nodes canceledCtx, cancel = WithGraphInterrupt(ctx) go func() { time.Sleep(500 * time.Millisecond) cancel(WithGraphInterruptTimeout(time.Hour)) }() _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1")) assert.Error(t, err) info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) result, err = r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) // infinite timeout canceledCtx, cancel = WithGraphInterrupt(ctx) go func() { time.Sleep(500 * time.Millisecond) cancel() }() _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2")) assert.Error(t, err) info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) // interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved canceledCtx, cancel = WithGraphInterrupt(ctx) go func() { time.Sleep(300 * time.Millisecond) cancel(WithGraphInterruptTimeout(0)) }() _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3")) assert.Error(t, err) info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) // dag multi canceled nodes gg := NewGraph[string, map[string]any]() _ = gg.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + "1", nil })) _ = gg.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input 
string) (output string, err error) { time.Sleep(3 * time.Second) return input + "2", nil }), WithOutputKey("2")) _ = gg.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { time.Sleep(3 * time.Second) return input + "3", nil }), WithOutputKey("3")) _ = gg.AddLambdaNode("4", InvokableLambda(func(ctx context.Context, input map[string]any) (output map[string]any, err error) { return input, nil })) _ = gg.AddEdge(START, "1") _ = gg.AddEdge("1", "2") _ = gg.AddEdge("1", "3") _ = gg.AddEdge("2", "4") _ = gg.AddEdge("3", "4") _ = gg.AddEdge("4", END) ctx = context.Background() rr, err := gg.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // interrupt after nodes canceledCtx, cancel = WithGraphInterrupt(ctx) go func() { time.Sleep(500 * time.Millisecond) cancel(WithGraphInterruptTimeout(time.Hour)) }() _, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("1")) assert.Error(t, err) info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.AfterNodes)) result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, map[string]any{ "2": "input12", "3": "input13", }, result2) // interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved canceledCtx, cancel = WithGraphInterrupt(ctx) go func() { time.Sleep(500 * time.Millisecond) cancel(WithGraphInterruptTimeout(0)) }() _, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("2")) assert.Error(t, err) info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.RerunNodes)) result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, map[string]any{ "2": "input12", "3": "input13", }, result2) } func TestPersistRerunInputNonStream(t *testing.T) { store := newInMemoryStore() var mu sync.Mutex var receivedInput string var callCount int g := 
NewGraph[string, string]() err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { mu.Lock() callCount++ currentCount := callCount receivedInput = input mu.Unlock() if currentCount == 1 { time.Sleep(2 * time.Second) } return input + "_processed", nil })) assert.NoError(t, err) err = g.AddEdge(START, "1") assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(store), ) assert.NoError(t, err) canceledCtx, cancel := WithGraphInterrupt(ctx) go func() { time.Sleep(100 * time.Millisecond) cancel(WithGraphInterruptTimeout(0)) }() _, err = r.Invoke(canceledCtx, "test_input", WithCheckPointID("cp1")) assert.NotNil(t, err) info, ok := ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, []string{"1"}, info.RerunNodes) mu.Lock() assert.Equal(t, "test_input", receivedInput) mu.Unlock() result, err := r.Invoke(ctx, "", WithCheckPointID("cp1")) assert.NoError(t, err) assert.Equal(t, "test_input_processed", result) mu.Lock() assert.Equal(t, "test_input", receivedInput) assert.Equal(t, 2, callCount) mu.Unlock() } func TestPersistRerunInputStream(t *testing.T) { store := newInMemoryStore() var mu sync.Mutex var receivedInput string var callCount int g := NewGraph[string, string]() err := g.AddLambdaNode("1", TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) { mu.Lock() callCount++ currentCount := callCount mu.Unlock() var sb string for { chunk, err := input.Recv() if err == io.EOF { break } if err != nil { return nil, err } sb += chunk } mu.Lock() receivedInput = sb mu.Unlock() if currentCount == 1 { time.Sleep(2 * time.Second) } return schema.StreamReaderFromArray([]string{sb + "_processed"}), nil })) assert.NoError(t, err) err = g.AddEdge(START, "1") assert.NoError(t, err) err = g.AddEdge("1", END) 
assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(store), ) assert.NoError(t, err) inputStream := schema.StreamReaderFromArray([]string{"chunk1", "chunk2", "chunk3"}) canceledCtx, cancel := WithGraphInterrupt(ctx) go func() { time.Sleep(100 * time.Millisecond) cancel(WithGraphInterruptTimeout(0)) }() _, err = r.Transform(canceledCtx, inputStream, WithCheckPointID("cp1")) assert.NotNil(t, err) info, ok := ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, []string{"1"}, info.RerunNodes) mu.Lock() assert.Equal(t, "chunk1chunk2chunk3", receivedInput) mu.Unlock() emptyInputStream := schema.StreamReaderFromArray([]string{}) resultStream, err := r.Transform(ctx, emptyInputStream, WithCheckPointID("cp1")) assert.NoError(t, err) var result string for { chunk, err := resultStream.Recv() if err == io.EOF { break } assert.NoError(t, err) result += chunk } assert.Equal(t, "chunk1chunk2chunk3_processed", result) mu.Lock() assert.Equal(t, "chunk1chunk2chunk3", receivedInput) assert.Equal(t, 2, callCount) mu.Unlock() } type testPersistRerunInputState struct { Prefix string } func TestPersistRerunInputWithPreHandler(t *testing.T) { store := newInMemoryStore() var mu sync.Mutex var receivedInput string var callCount int schema.Register[testPersistRerunInputState]() g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) *testPersistRerunInputState { return &testPersistRerunInputState{Prefix: "prefix_"} })) err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { mu.Lock() callCount++ currentCount := callCount receivedInput = input mu.Unlock() if currentCount == 1 { time.Sleep(2 * time.Second) } return input + "_processed", nil }), WithStatePreHandler(func(ctx context.Context, in string, s *testPersistRerunInputState) (string, error) { return s.Prefix + in, nil })) assert.NoError(t, err) err = g.AddEdge(START, "1") 
assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(store), ) assert.NoError(t, err) canceledCtx, cancel := WithGraphInterrupt(ctx) go func() { time.Sleep(100 * time.Millisecond) cancel(WithGraphInterruptTimeout(0)) }() _, err = r.Invoke(canceledCtx, "test_input", WithCheckPointID("cp1")) assert.NotNil(t, err) info, ok := ExtractInterruptInfo(err) assert.True(t, ok) if ok { assert.Equal(t, []string{"1"}, info.RerunNodes) } mu.Lock() assert.Equal(t, "prefix_test_input", receivedInput) mu.Unlock() result, err := r.Invoke(ctx, "", WithCheckPointID("cp1")) assert.NoError(t, err) assert.Equal(t, "prefix_test_input_processed", result) mu.Lock() assert.Equal(t, "prefix_test_input", receivedInput) assert.Equal(t, 2, callCount) mu.Unlock() } func TestPersistRerunInputBackwardCompatibility(t *testing.T) { store := newInMemoryStore() var receivedInput string var callCount int g := NewGraph[string, string]() err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { callCount++ receivedInput = input if len(input) > 0 { return "", StatefulInterrupt(ctx, "interrupt", input) } _, _, restoredInput := GetInterruptState[string](ctx) return restoredInput + "_processed", nil })) assert.NoError(t, err) err = g.AddEdge(START, "1") assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(store), ) assert.NoError(t, err) _, err = r.Invoke(ctx, "test_input", WithCheckPointID("cp1")) assert.NotNil(t, err) info, ok := ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, []string{"1"}, info.RerunNodes) assert.Equal(t, "test_input", receivedInput) result, err := r.Invoke(ctx, "", WithCheckPointID("cp1")) assert.NoError(t, err) assert.Equal(t, "test_input_processed", result) 
assert.Equal(t, "", receivedInput) assert.Equal(t, 2, callCount) } func TestPersistRerunInputSubGraph(t *testing.T) { store := newInMemoryStore() var mu sync.Mutex var receivedInput string var callCount int subG := NewGraph[string, string]() err := subG.AddLambdaNode("sub1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { mu.Lock() callCount++ currentCount := callCount receivedInput = input mu.Unlock() if currentCount == 1 { time.Sleep(2 * time.Second) } return input + "_sub_processed", nil })) assert.NoError(t, err) err = subG.AddEdge(START, "sub1") assert.NoError(t, err) err = subG.AddEdge("sub1", END) assert.NoError(t, err) g := NewGraph[string, string]() err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + "_main", nil })) assert.NoError(t, err) err = g.AddGraphNode("2", subG) assert.NoError(t, err) err = g.AddEdge(START, "1") assert.NoError(t, err) err = g.AddEdge("1", "2") assert.NoError(t, err) err = g.AddEdge("2", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(store), ) assert.NoError(t, err) canceledCtx, cancel := WithGraphInterrupt(ctx) go func() { time.Sleep(100 * time.Millisecond) cancel(WithGraphInterruptTimeout(0)) }() _, err = r.Invoke(canceledCtx, "test", WithCheckPointID("cp1")) assert.NotNil(t, err) info, ok := ExtractInterruptInfo(err) assert.True(t, ok, "Expected interrupt error, got: %v", err) if len(info.SubGraphs) > 0 { assert.Contains(t, info.SubGraphs, "2") subInfo := info.SubGraphs["2"] assert.Equal(t, []string{"sub1"}, subInfo.RerunNodes) } else { assert.Equal(t, []string{"2"}, info.RerunNodes) } mu.Lock() assert.Equal(t, "test_main", receivedInput) mu.Unlock() result, err := r.Invoke(ctx, "", WithCheckPointID("cp1")) assert.NoError(t, err) assert.Equal(t, "test_main_sub_processed", result) mu.Lock() assert.Equal(t, "test_main", 
receivedInput) assert.Equal(t, 2, callCount) mu.Unlock() } type longRunningToolInput struct { Input string `json:"input"` } func TestToolsNodeWithExternalGraphInterrupt(t *testing.T) { store := newInMemoryStore() ctx := context.Background() var mu sync.Mutex var callCount int longRunningToolInfo := &schema.ToolInfo{ Name: "long_running_tool", Desc: "A tool that takes a long time to run", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "input": {Type: "string", Desc: "input"}, }), } longRunningTool := newCheckpointTestTool(longRunningToolInfo, func(ctx context.Context, in *longRunningToolInput) (string, error) { mu.Lock() callCount++ currentCount := callCount mu.Unlock() if currentCount == 1 { time.Sleep(2 * time.Second) } return "result_" + in.Input, nil }) toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{ Tools: []tool.BaseTool{longRunningTool}, }) assert.NoError(t, err) g := NewGraph[*schema.Message, []*schema.Message]() err = g.AddToolsNode("tools", toolsNode) assert.NoError(t, err) err = g.AddEdge(START, "tools") assert.NoError(t, err) err = g.AddEdge("tools", END) assert.NoError(t, err) r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(store), ) assert.NoError(t, err) inputMsg := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{{ ID: "call_1", Type: "function", Function: schema.FunctionCall{ Name: "long_running_tool", Arguments: `{"input": "test"}`, }, }}, } canceledCtx, cancel := WithGraphInterrupt(ctx) go func() { time.Sleep(100 * time.Millisecond) cancel(WithGraphInterruptTimeout(0)) }() _, err = r.Invoke(canceledCtx, inputMsg, WithCheckPointID("cp1")) assert.Error(t, err) info, ok := ExtractInterruptInfo(err) assert.True(t, ok, "Expected interrupt error, got: %v", err) if ok { assert.Equal(t, []string{"tools"}, info.RerunNodes) } result, err := r.Invoke(ctx, &schema.Message{}, WithCheckPointID("cp1")) assert.NoError(t, err) assert.Len(t, result, 1) assert.Equal(t, 
`"result_test"`, result[0].Content) mu.Lock() assert.Equal(t, 2, callCount) mu.Unlock() } type checkpointTestTool[I, O any] struct { info *schema.ToolInfo fn func(ctx context.Context, in I) (O, error) } func newCheckpointTestTool[I, O any](info *schema.ToolInfo, f func(ctx context.Context, in I) (O, error)) tool.InvokableTool { return &checkpointTestTool[I, O]{ info: info, fn: f, } } func (f *checkpointTestTool[I, O]) Info(ctx context.Context) (*schema.ToolInfo, error) { return f.info, nil } func (f *checkpointTestTool[I, O]) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { t := generic.NewInstance[I]() err := sonic.UnmarshalString(argumentsInJSON, t) if err != nil { return "", err } o, err := f.fn(ctx, t) if err != nil { return "", err } return sonic.MarshalString(o) } ================================================ FILE: compose/component_to_graph_node.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" ) func toComponentNode[I, O, TOption any]( node any, componentType component, invoke Invoke[I, O, TOption], stream Stream[I, O, TOption], collect Collect[I, O, TOption], transform Transform[I, O, TOption], opts ...GraphAddNodeOpt, ) (*graphNode, *graphAddNodeOpts) { meta := parseExecutorInfoFromComponent(componentType, node) info, options := getNodeInfo(opts...) run := runnableLambda(invoke, stream, collect, transform, !meta.isComponentCallbackEnabled, ) gn := toNode(info, run, nil, meta, node, opts...) return gn, options } func toEmbeddingNode(node embedding.Embedder, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfEmbedding, node.EmbedStrings, nil, nil, nil, opts...) } func toRetrieverNode(node retriever.Retriever, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfRetriever, node.Retrieve, nil, nil, nil, opts...) } func toLoaderNode(node document.Loader, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfLoader, node.Load, nil, nil, nil, opts...) } func toIndexerNode(node indexer.Indexer, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfIndexer, node.Store, nil, nil, nil, opts...) } func toChatModelNode(node model.BaseChatModel, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfChatModel, node.Generate, node.Stream, nil, nil, opts...) 
} func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfPrompt, node.Format, nil, nil, nil, opts...) } func toDocumentTransformerNode(node document.Transformer, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfTransformer, node.Transform, nil, nil, nil, opts...) } func toToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, ComponentOfToolsNode, node.Invoke, node.Stream, nil, nil, opts...) } func toLambdaNode(node *Lambda, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { info, options := getNodeInfo(opts...) gn := toNode(info, node.executor, nil, node.executor.meta, node, opts...) return gn, options } func toAnyGraphNode(node AnyGraph, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { meta := parseExecutorInfoFromComponent(node.component(), node) info, options := getNodeInfo(opts...) gn := toNode(info, nil, node, meta, node, opts...) return gn, options } func toPassthroughNode(opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { node := composablePassthrough() info, options := getNodeInfo(opts...) gn := toNode(info, node, nil, node.meta, node, opts...) return gn, options } func toNode(nodeInfo *nodeInfo, executor *composableRunnable, graph AnyGraph, meta *executorMeta, instance any, opts ...GraphAddNodeOpt) *graphNode { if meta == nil { meta = &executorMeta{} } gn := &graphNode{ nodeInfo: nodeInfo, cr: executor, g: graph, executorMeta: meta, instance: instance, opts: opts, } return gn } ================================================ FILE: compose/dag.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "fmt" ) func dagChannelBuilder(controlDependencies []string, dataDependencies []string, zeroValue func() any, emptyStream func() streamReader) channel { deps := make(map[string]dependencyState, len(controlDependencies)) for _, dep := range controlDependencies { deps[dep] = dependencyStateWaiting } indirect := make(map[string]bool, len(dataDependencies)) for _, dep := range dataDependencies { indirect[dep] = false } return &dagChannel{ Values: make(map[string]any), ControlPredecessors: deps, DataPredecessors: indirect, zeroValue: zeroValue, emptyStream: emptyStream, } } type dependencyState uint8 const ( dependencyStateWaiting dependencyState = iota dependencyStateReady dependencyStateSkipped ) type dagChannel struct { zeroValue func() any emptyStream func() streamReader ControlPredecessors map[string]dependencyState Values map[string]any DataPredecessors map[string]bool // if all dependencies have been skipped, indirect dependencies won't effect. 
Skipped bool mergeConfig FanInMergeConfig } func (ch *dagChannel) setMergeConfig(cfg FanInMergeConfig) { ch.mergeConfig.StreamMergeWithSourceEOF = cfg.StreamMergeWithSourceEOF } func (ch *dagChannel) load(c channel) error { dc, ok := c.(*dagChannel) if !ok { return fmt.Errorf("load dag channel fail, got %T, want *dagChannel", c) } ch.ControlPredecessors = dc.ControlPredecessors ch.DataPredecessors = dc.DataPredecessors ch.Skipped = dc.Skipped ch.Values = dc.Values return nil } func (ch *dagChannel) reportValues(ins map[string]any) error { if ch.Skipped { return nil } for k, v := range ins { if _, ok := ch.DataPredecessors[k]; !ok { continue } ch.DataPredecessors[k] = true ch.Values[k] = v } return nil } func (ch *dagChannel) reportDependencies(dependencies []string) { if ch.Skipped { return } for _, dep := range dependencies { if _, ok := ch.ControlPredecessors[dep]; ok { ch.ControlPredecessors[dep] = dependencyStateReady } } return } func (ch *dagChannel) reportSkip(keys []string) bool { for _, k := range keys { if _, ok := ch.ControlPredecessors[k]; ok { ch.ControlPredecessors[k] = dependencyStateSkipped } if _, ok := ch.DataPredecessors[k]; ok { ch.DataPredecessors[k] = true } } allSkipped := true for _, state := range ch.ControlPredecessors { if state != dependencyStateSkipped { allSkipped = false break } } ch.Skipped = allSkipped return allSkipped } func (ch *dagChannel) get(isStream bool, name string, edgeHandler *edgeHandlerManager) ( any, bool, error) { if ch.Skipped { return nil, false, nil } if len(ch.ControlPredecessors)+len(ch.DataPredecessors) == 0 { return nil, false, nil } for _, state := range ch.ControlPredecessors { if state == dependencyStateWaiting { return nil, false, nil } } for _, ready := range ch.DataPredecessors { if !ready { return nil, false, nil } } defer func() { ch.Values = make(map[string]any) for k := range ch.ControlPredecessors { ch.ControlPredecessors[k] = dependencyStateWaiting } for k := range ch.DataPredecessors { 
ch.DataPredecessors[k] = false } }() valueList := make([]any, len(ch.Values)) names := make([]string, len(ch.Values)) i := 0 for k, value := range ch.Values { resolvedV, err := edgeHandler.handle(k, name, value, isStream) if err != nil { return nil, false, err } valueList[i] = resolvedV names[i] = k i++ } if len(valueList) == 0 { if isStream { return ch.emptyStream(), true, nil } return ch.zeroValue(), true, nil } if len(valueList) == 1 { return valueList[0], true, nil } mergeOpts := &mergeOptions{ streamMergeWithSourceEOF: ch.mergeConfig.StreamMergeWithSourceEOF, names: names, } v, err := mergeValues(valueList, mergeOpts) if err != nil { return nil, false, err } return v, true, nil } func (ch *dagChannel) convertValues(fn func(map[string]any) error) error { return fn(ch.Values) } ================================================ FILE: compose/dag_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "fmt" "io" "testing" ) func TestDAG(t *testing.T) { var err error g := NewGraph[string, string]() err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }), WithOutputKey("1")) if err != nil { t.Fatal(err) } err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }), WithOutputKey("2")) if err != nil { t.Fatal(err) } err = g.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input map[string]any) (output string, err error) { if _, ok := input["1"]; !ok { return "", fmt.Errorf("node 1 output fail: %+v", input) } if _, ok := input["2"]; !ok { return "", fmt.Errorf("node 2 output fail: %+v", input) } return input["1"].(string) + input["2"].(string), nil })) if err != nil { t.Fatal(err) } err = g.AddLambdaNode("4", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil })) if err != nil { t.Fatal(err) } err = g.AddLambdaNode("5", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil })) if err != nil { t.Fatal(err) } err = g.AddLambdaNode("6", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }), WithOutputKey("6")) if err != nil { t.Fatal(err) } err = g.AddLambdaNode("7", InvokableLambda(func(ctx context.Context, input map[string]any) (output string, err error) { if _, ok := input["1"]; !ok { return "", fmt.Errorf("7:node 1 output fail: %+v", input) } if _, ok := input["6"]; !ok { return "", fmt.Errorf("7:node 6 output fail: %+v", input) } return input["1"].(string) + input["6"].(string), nil })) if err != nil { t.Fatal(err) } err = g.AddEdge("1", "3") if err != nil { t.Fatal(err) } err = g.AddEdge("2", "3") if err != nil { t.Fatal(err) } err = g.AddEdge("3", "4") if err != nil { t.Fatal(err) } err = g.AddEdge("4", "5") if err 
!= nil { t.Fatal(err) } err = g.AddEdge("4", "6") if err != nil { t.Fatal(err) } err = g.AddEdge("6", "7") if err != nil { t.Fatal(err) } err = g.AddEdge("1", "7") if err != nil { t.Fatal(err) } err = g.AddEdge(START, "1") if err != nil { t.Fatal(err) } err = g.AddEdge(START, "2") if err != nil { t.Fatal(err) } err = g.AddEdge("7", END) if err != nil { t.Fatal(err) } runner, err := g.Compile(context.Background(), WithNodeTriggerMode(AllPredecessor)) if err != nil { t.Fatal(err) } // success ctx := context.Background() out, err := runner.Invoke(ctx, "hello") if err != nil { t.Fatal(err) } if out != "hellohellohello" { t.Fatalf("node7 fail") } result, err := runner.Invoke(ctx, "1") if err != nil { t.Fatal(err) } if result != "111" { t.Fatalf("runner invoke fail, output: %s", result) } streamResult, err := runner.Stream(ctx, "1") if err != nil { t.Fatal(err) } defer streamResult.Close() ret := "" for { chunk, err := streamResult.Recv() if err == io.EOF { break } if err != nil { t.Fatal(err) } ret += chunk } if ret != "111" { t.Fatalf("runner stream fail, output: %s", ret) } // loop gg := NewGraph[string, map[string]any]() err = gg.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }), WithOutputKey("1")) if err != nil { t.Fatal(err) } err = gg.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input map[string]any) (output string, err error) { return input["1"].(string), nil })) if err != nil { t.Fatal(err) } err = gg.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }), WithOutputKey("3")) if err != nil { t.Fatal(err) } err = gg.AddEdge("1", "2") if err != nil { t.Fatal(err) } err = gg.AddEdge("2", "3") if err != nil { t.Fatal(err) } err = gg.AddEdge("3", "2") if err != nil { t.Fatal(err) } err = gg.AddEdge(START, "1") if err != nil { t.Fatal(err) } err = gg.AddEdge("3", END) if err != nil { t.Fatal(err) } _, err = 
gg.compile(ctx, &graphCompileOptions{nodeTriggerMode: AllPredecessor}) if err == nil { t.Fatal("cannot validate loop") } } ================================================ FILE: compose/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package compose provides graph and workflow primitives to build // composable, interruptible execution pipelines with callback support. package compose ================================================ FILE: compose/error.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "errors" "fmt" "reflect" "strings" ) // ErrExceedMaxSteps graph will throw this error when the number of steps exceeds the maximum number of steps. 
var ErrExceedMaxSteps = errors.New("exceeds max steps") func newUnexpectedInputTypeErr(expected reflect.Type, got reflect.Type) error { return fmt.Errorf("unexpected input type. expected: %v, got: %v", expected, got) } type defaultImplAction string const ( actionInvokeByStream defaultImplAction = "InvokeByStream" actionInvokeByCollect defaultImplAction = "InvokeByCollect" actionInvokeByTransform defaultImplAction = "InvokeByTransform" actionStreamByInvoke defaultImplAction = "StreamByInvoke" actionStreamByTransform defaultImplAction = "StreamByTransform" actionStreamByCollect defaultImplAction = "StreamByCollect" actionCollectByTransform defaultImplAction = "CollectByTransform" actionCollectByInvoke defaultImplAction = "CollectByInvoke" actionCollectByStream defaultImplAction = "CollectByStream" actionTransformByStream defaultImplAction = "TransformByStream" actionTransformByCollect defaultImplAction = "TransformByCollect" actionTransformByInvoke defaultImplAction = "TransformByInvoke" ) func newStreamReadError(err error) error { return fmt.Errorf("failed to read from stream. error: %w", err) } func newGraphRunError(err error) error { return &internalError{ typ: internalErrorTypeGraphRun, nodePath: NodePath{}, origError: err, } } func wrapGraphNodeError(nodeKey string, err error) error { if ok := isInterruptError(err); ok { return err } var ie *internalError ok := errors.As(err, &ie) if !ok { return &internalError{ typ: internalErrorTypeNodeRun, nodePath: NodePath{path: []string{nodeKey}}, origError: err, } } ie.nodePath.path = append([]string{nodeKey}, ie.nodePath.path...) 
return ie } type internalErrorType string const ( internalErrorTypeNodeRun = "NodeRunError" internalErrorTypeGraphRun = "GraphRunError" ) type internalError struct { typ internalErrorType nodePath NodePath origError error } func (i *internalError) Error() string { sb := strings.Builder{} sb.WriteString(string("[" + i.typ + "] ")) sb.WriteString(i.origError.Error()) if len(i.nodePath.path) > 0 { sb.WriteString("\n------------------------\n") sb.WriteString("node path: [") for j := 0; j < len(i.nodePath.path)-1; j++ { sb.WriteString(i.nodePath.path[j] + ", ") } sb.WriteString(i.nodePath.path[len(i.nodePath.path)-1]) sb.WriteString("]") } sb.WriteString("") return sb.String() } func (i *internalError) Unwrap() error { return i.origError } ================================================ FILE: compose/error_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "errors" "testing" "time" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestCommonError(t *testing.T) { g := NewGraph[string, string]() assert.NoError(t, g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", errors.New("my error") }))) assert.NoError(t, g.AddEdge(START, "1")) assert.NoError(t, g.AddEdge("1", END)) ctx := context.Background() r, err := g.Compile(ctx) assert.NoError(t, err) // node error _, err = r.Invoke(ctx, "input") var ie *internalError assert.True(t, errors.As(err, &ie)) assert.Equal(t, "my error", ie.origError.Error()) // wrapper error sr, sw := schema.Pipe[string](0) sw.Close() _, err = r.Transform(ctx, sr) assert.True(t, errors.As(err, &ie)) assert.ErrorContains(t, ie.origError, "stream reader is empty, concat fail") assert.Equal(t, []string{"1"}, ie.nodePath.path) println(err.Error()) } func TestSubGraphNodeError(t *testing.T) { subG := NewGraph[string, string]() assert.NoError(t, subG.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", errors.New("my error") }))) assert.NoError(t, subG.AddEdge(START, "1")) assert.NoError(t, subG.AddEdge("1", END)) g := NewGraph[string, string]() assert.NoError(t, g.AddGraphNode("a", subG)) assert.NoError(t, g.AddEdge(START, "a")) assert.NoError(t, g.AddEdge("a", END)) ctx := context.Background() r, err := g.Compile(ctx) assert.NoError(t, err) _, err = r.Invoke(ctx, "input") var ie *internalError assert.True(t, errors.As(err, &ie)) assert.Equal(t, "my error", ie.origError.Error()) assert.Equal(t, []string{"a", "1"}, ie.nodePath.path) } func TestContextCancelDuringRun(t *testing.T) { // Create a graph with a long-running node to test context cancellation g := NewGraph[string, string]() // Add a node that waits for some time (long enough to be cancelled) assert.NoError(t, g.AddLambdaNode("slow_node", 
InvokableLambda(func(ctx context.Context, input string) (output string, err error) { select { case <-ctx.Done(): // Return context's error when cancelled return "", ctx.Err() case <-time.After(200 * time.Millisecond): return input + "_processed", nil } }))) assert.NoError(t, g.AddEdge(START, "slow_node")) assert.NoError(t, g.AddEdge("slow_node", END)) // Create a context that we can cancel ctx, cancel := context.WithCancel(context.Background()) // Compile the graph r, err := g.Compile(ctx) assert.NoError(t, err) // Run the invoke in a goroutine resultCh := make(chan error) go func() { _, err := r.Invoke(ctx, "input") resultCh <- err }() // Cancel the context after a short delay time.Sleep(50 * time.Millisecond) cancel() // Get the result err = <-resultCh // Verify the error is related to context cancellation assert.Error(t, err) // Check error type and content var ie *internalError assert.True(t, errors.As(err, &ie)) // Error path should contain the node assert.Equal(t, []string{"slow_node"}, ie.nodePath.path) // Original error should be context.Canceled assert.ErrorIs(t, ie.origError, context.Canceled) // Test unwrap capability unwrappedErr := ie.Unwrap() assert.ErrorIs(t, unwrappedErr, context.Canceled) } ================================================ FILE: compose/field_mapping.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "errors" "fmt" "reflect" "runtime/debug" "strings" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/internal/safe" "github.com/cloudwego/eino/schema" ) type FieldMapping struct { fromNodeKey string from string to string customExtractor func(input any) (any, error) } // String returns the string representation of the FieldMapping. func (m *FieldMapping) String() string { var sb strings.Builder sb.WriteString("[from ") if m.from != "" { sb.WriteString(m.from) sb.WriteString("(field) of ") } sb.WriteString(m.fromNodeKey) if m.to != "" { sb.WriteString(" to ") sb.WriteString(m.to) sb.WriteString("(field)") } sb.WriteString("]") return sb.String() } // FromField creates a FieldMapping that maps a single predecessor field to the entire successor input. // This is an exclusive mapping - once set, no other field mappings can be added since the successor input // has already been fully mapped. // Field: either the field of a struct, or the key of a map. func FromField(from string) *FieldMapping { return &FieldMapping{ from: from, } } // ToField creates a FieldMapping that maps the entire predecessor output to a single successor field. // Field: either the field of a struct, or the key of a map. func ToField(to string, opts ...FieldMappingOption) *FieldMapping { fm := &FieldMapping{ to: to, } for _, opt := range opts { opt(fm) } return fm } // MapFields creates a FieldMapping that maps a single predecessor field to a single successor field. // Field: either the field of a struct, or the key of a map. 
func MapFields(from, to string) *FieldMapping { return &FieldMapping{ from: from, to: to, } } func (m *FieldMapping) FromNodeKey() string { return m.fromNodeKey } func (m *FieldMapping) FromPath() FieldPath { return splitFieldPath(m.from) } func (m *FieldMapping) ToPath() FieldPath { return splitFieldPath(m.to) } func (m *FieldMapping) Equals(o *FieldMapping) bool { if m == nil { return o == nil } if o == nil || m.customExtractor != nil || o.customExtractor != nil { return false } return m.from == o.from && m.to == o.to && m.fromNodeKey == o.fromNodeKey } // FieldPath represents a path to a nested field in a struct or map. // Each element in the path is either: // - a struct field name // - a map key // // Example paths: // - []string{"user"} // top-level field // - []string{"user", "name"} // nested struct field // - []string{"users", "admin"} // map key access type FieldPath []string func (fp *FieldPath) join() string { return strings.Join(*fp, pathSeparator) } func splitFieldPath(path string) FieldPath { p := strings.Split(path, pathSeparator) if len(p) == 1 && p[0] == "" { return FieldPath{} } return p } // pathSeparator is a special character (Unit Separator) used internally to join path elements. // This character is chosen because it's extremely unlikely to appear in user-defined field names or map keys. const pathSeparator = "\x1F" // FromFieldPath creates a FieldMapping that maps a single predecessor field path to the entire successor input. // This is an exclusive mapping - once set, no other field mappings can be added since the successor input // has already been fully mapped. // // Example: // // // Maps the 'name' field from nested 'user.profile' to the entire successor input // FromFieldPath(FieldPath{"user", "profile", "name"}) // // Note: The field path elements must not contain the internal path separator character ('\x1F'). 
func FromFieldPath(fromFieldPath FieldPath) *FieldMapping { return &FieldMapping{ from: fromFieldPath.join(), } } // ToFieldPath creates a FieldMapping that maps the entire predecessor output to a single successor field path. // // Example: // // // Maps the entire predecessor output to response.data.userName // ToFieldPath(FieldPath{"response", "data", "userName"}) // // Note: The field path elements must not contain the internal path separator character ('\x1F'). func ToFieldPath(toFieldPath FieldPath, opts ...FieldMappingOption) *FieldMapping { fm := &FieldMapping{ to: toFieldPath.join(), } for _, opt := range opts { opt(fm) } return fm } // MapFieldPaths creates a FieldMapping that maps a single predecessor field path to a single successor field path. // // Example: // // // Maps user.profile.name to response.userName // MapFieldPaths( // FieldPath{"user", "profile", "name"}, // FieldPath{"response", "userName"}, // ) // // Note: The field path elements must not contain the internal path separator character ('\x1F'). func MapFieldPaths(fromFieldPath, toFieldPath FieldPath) *FieldMapping { return &FieldMapping{ from: fromFieldPath.join(), to: toFieldPath.join(), } } // FieldMappingOption is a functional option for configuring a FieldMapping. type FieldMappingOption func(*FieldMapping) // WithCustomExtractor sets a custom extractor function for the FieldMapping. // The extractor function is used to extract a value from the 'source' of the FieldMapping. // NOTE: if specified in this way, Eino can only check the validity of the field mapping at request time.. 
func WithCustomExtractor(extractor func(input any) (any, error)) FieldMappingOption { return func(m *FieldMapping) { m.customExtractor = extractor } } func (m *FieldMapping) targetPath() FieldPath { return splitFieldPath(m.to) } func buildFieldMappingConverter[I any]() func(input any) (any, error) { return func(input any) (any, error) { in, ok := input.(map[string]any) if !ok { panic(newUnexpectedInputTypeErr(reflect.TypeOf(map[string]any{}), reflect.TypeOf(input))) } return convertTo(in, generic.TypeOf[I]()), nil } } func buildStreamFieldMappingConverter[I any]() func(input streamReader) streamReader { return func(input streamReader) streamReader { s, ok := unpackStreamReader[map[string]any](input) if !ok { panic("mappingStreamAssign incoming streamReader chunk type not map[string]any") } return packStreamReader(schema.StreamReaderWithConvert(s, func(v map[string]any) (I, error) { t := convertTo(v, generic.TypeOf[I]()) return t.(I), nil })) } } func convertTo(mappings map[string]any, typ reflect.Type) any { tValue := newInstanceByType(typ) if !tValue.CanAddr() { tValue = newInstanceByType(reflect.PointerTo(typ)).Elem() } for mapping, taken := range mappings { tValue = assignOne(tValue, taken, mapping) } return tValue.Interface() } func assignOne(destValue reflect.Value, taken any, to string) reflect.Value { if len(to) == 0 { // assign to output directly destValue.Set(reflect.ValueOf(taken)) return destValue } var ( toPaths = splitFieldPath(to) originalDestValue = destValue parentMap reflect.Value parentKey string ) for { path := toPaths[0] toPaths = toPaths[1:] if len(toPaths) == 0 { toSet := reflect.ValueOf(taken) if destValue.Type() == reflect.TypeOf((*any)(nil)).Elem() { existingMap, ok := destValue.Interface().(map[string]any) if ok { destValue = reflect.ValueOf(existingMap) } else { mapValue := reflect.MakeMap(reflect.TypeOf(map[string]any{})) destValue.Set(mapValue) destValue = mapValue } } if destValue.Kind() == reflect.Map { key := reflect.ValueOf(path) 
keyType := destValue.Type().Key() if keyType != strType { key = key.Convert(keyType) } if !toSet.IsValid() { toSet = reflect.Zero(destValue.Type().Elem()) } destValue.SetMapIndex(key, toSet) if parentMap.IsValid() { parentMap.SetMapIndex(reflect.ValueOf(parentKey), destValue) } return originalDestValue } ptrValue := destValue for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } if !toSet.IsValid() { // just skip it, because this 'nil' is the zero value of the corresponding struct field } else { field := destValue.FieldByName(path) field.Set(toSet) } if parentMap.IsValid() { parentMap.SetMapIndex(reflect.ValueOf(parentKey), ptrValue) } return originalDestValue } if destValue.Type() == reflect.TypeOf((*any)(nil)).Elem() { existingMap, ok := destValue.Interface().(map[string]any) if ok { destValue = reflect.ValueOf(existingMap) } else { mapValue := reflect.MakeMap(reflect.TypeOf(map[string]any{})) destValue.Set(mapValue) destValue = mapValue } } if destValue.Kind() == reflect.Map { keyValue := reflect.ValueOf(path) valueValue := destValue.MapIndex(keyValue) if !valueValue.IsValid() { valueValue = newInstanceByType(destValue.Type().Elem()) destValue.SetMapIndex(keyValue, valueValue) } if parentMap.IsValid() { parentMap.SetMapIndex(reflect.ValueOf(parentKey), destValue) } parentMap = destValue parentKey = path destValue = valueValue continue } ptrValue := destValue for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } field := destValue.FieldByName(path) instantiateIfNeeded(field) if parentMap.IsValid() { parentMap.SetMapIndex(reflect.ValueOf(parentKey), ptrValue) parentMap = reflect.Value{} parentKey = "" } destValue = field } } func instantiateIfNeeded(field reflect.Value) { if field.Kind() == reflect.Ptr { if field.IsNil() { field.Set(reflect.New(field.Type().Elem())) } } else if field.Kind() == reflect.Map { if field.IsNil() { field.Set(reflect.MakeMap(field.Type())) } } } func newInstanceByType(typ reflect.Type) reflect.Value { switch 
typ.Kind() { case reflect.Map: return reflect.MakeMap(typ) case reflect.Slice, reflect.Array: slice := reflect.New(typ).Elem() slice.Set(reflect.MakeSlice(typ, 0, 0)) return slice case reflect.Ptr: typ = typ.Elem() origin := reflect.New(typ) nested := newInstanceByType(typ) origin.Elem().Set(nested) return origin default: return reflect.New(typ).Elem() } } func checkAndExtractFromField(fromField string, input reflect.Value) (reflect.Value, error) { f := input.FieldByName(fromField) if !f.IsValid() { return reflect.Value{}, fmt.Errorf("field mapping from a struct field, but field not found. field=%v, inputType=%v", fromField, input.Type()) } if !f.CanInterface() { return reflect.Value{}, fmt.Errorf("field mapping from a struct field, but field not exported. field= %v, inputType=%v", fromField, input.Type()) } return f, nil } type errMapKeyNotFound struct { mapKey string } func (e *errMapKeyNotFound) Error() string { return fmt.Sprintf("key=%s", e.mapKey) } type errInterfaceNotValidForFieldMapping struct { interfaceType reflect.Type actualType reflect.Type } func (e *errInterfaceNotValidForFieldMapping) Error() string { return fmt.Sprintf("field mapping from an interface type, but actual type is not struct, struct ptr or map. InterfaceType= %v, ActualType= %v", e.interfaceType, e.actualType) } func checkAndExtractFromMapKey(fromMapKey string, input reflect.Value) (reflect.Value, error) { key := reflect.ValueOf(fromMapKey) if input.Type().Key() != strType { key = key.Convert(input.Type().Key()) } v := input.MapIndex(key) if !v.IsValid() { return reflect.Value{}, fmt.Errorf("field mapping from a map key, but key not found in input. 
%w", &errMapKeyNotFound{mapKey: fromMapKey}) } return v, nil } func checkAndExtractFieldType(paths []string, typ reflect.Type) (extracted reflect.Type, remainingPaths FieldPath, err error) { extracted = typ for i, field := range paths { for extracted.Kind() == reflect.Ptr { extracted = extracted.Elem() } if extracted.Kind() == reflect.Map { if !strType.ConvertibleTo(extracted.Key()) { return nil, nil, fmt.Errorf("type[%v] is not a map with string or string alias key", extracted) } extracted = extracted.Elem() continue } if extracted.Kind() == reflect.Struct { f, ok := extracted.FieldByName(field) if !ok { return nil, nil, fmt.Errorf("type[%v] has no field[%s]", extracted, field) } if !f.IsExported() { return nil, nil, fmt.Errorf("type[%v] has an unexported field[%s]", extracted.String(), field) } extracted = f.Type continue } if extracted.Kind() == reflect.Interface { return extracted, paths[i:], nil } return nil, nil, fmt.Errorf("intermediate type[%v] is not valid", extracted) } return extracted, nil, nil } var strType = reflect.TypeOf("") func fieldMap(mappings []*FieldMapping, allowMapKeyNotFound bool, uncheckedSourcePaths map[string]FieldPath) func(any) (map[string]any, error) { return func(input any) (result map[string]any, err error) { result = make(map[string]any, len(mappings)) var inputValue reflect.Value loop: for _, mapping := range mappings { if mapping.customExtractor != nil { result[mapping.to], err = mapping.customExtractor(input) if err != nil { return nil, err } continue } if len(mapping.from) == 0 { result[mapping.to] = input continue } fromPath := splitFieldPath(mapping.from) if !inputValue.IsValid() { inputValue = reflect.ValueOf(input) } var ( pathInputValue = inputValue pathInputType = inputValue.Type() taken = input ) for i, path := range fromPath { for pathInputValue.Kind() == reflect.Ptr { pathInputValue = pathInputValue.Elem() } if !pathInputValue.IsValid() { return nil, fmt.Errorf("intermediate source value on path=%v is nil for type 
[%v]", fromPath[:i+1], pathInputType) } if pathInputValue.Kind() == reflect.Map && pathInputValue.IsNil() { return nil, fmt.Errorf("intermediate source value on path=%v is nil for map type [%v]", fromPath[:i+1], pathInputType) } taken, pathInputType, err = takeOne(pathInputValue, pathInputType, path) if err != nil { // we deferred check from Compile time to request time for interface types, so we won't panic here var interfaceNotValidErr *errInterfaceNotValidForFieldMapping if errors.As(err, &interfaceNotValidErr) { return nil, err } // map key not found can only be a request time error, so we won't panic here var mapKeyNotFoundErr *errMapKeyNotFound if errors.As(err, &mapKeyNotFoundErr) { if allowMapKeyNotFound { continue loop } return nil, err } if uncheckedSourcePaths != nil { uncheckedPath, ok := uncheckedSourcePaths[mapping.from] if ok && len(uncheckedPath) >= len(fromPath)-i { // the err happens on the mapping source path which is unchecked at request time, so we won't panic here return nil, err } } panic(safe.NewPanicErr(err, debug.Stack())) } if i < len(fromPath)-1 { pathInputValue = reflect.ValueOf(taken) } } result[mapping.to] = taken } return result, nil } } func streamFieldMap(mappings []*FieldMapping, uncheckedSourcePaths map[string]FieldPath) func(streamReader) streamReader { return func(input streamReader) streamReader { return packStreamReader(schema.StreamReaderWithConvert(input.toAnyStreamReader(), fieldMap(mappings, true, uncheckedSourcePaths))) } } func takeOne(inputValue reflect.Value, inputType reflect.Type, from string) (taken any, takenType reflect.Type, err error) { var f reflect.Value switch k := inputValue.Kind(); k { case reflect.Map: f, err = checkAndExtractFromMapKey(from, inputValue) if err != nil { return nil, nil, err } return f.Interface(), f.Type(), nil case reflect.Struct: f, err = checkAndExtractFromField(from, inputValue) if err != nil { return nil, nil, err } return f.Interface(), f.Type(), nil default: if inputType.Kind() == 
reflect.Interface { return nil, nil, &errInterfaceNotValidForFieldMapping{ interfaceType: inputType, actualType: inputValue.Type(), } } panic("when take one value from source, value not map or struct, and type not interface") } } func isFromAll(mappings []*FieldMapping) bool { for _, mapping := range mappings { if len(mapping.from) == 0 && mapping.customExtractor == nil { return true } } return false } func fromFields(mappings []*FieldMapping) bool { for _, mapping := range mappings { if len(mapping.from) == 0 || mapping.customExtractor != nil { return false } } return true } func isToAll(mappings []*FieldMapping) bool { for _, mapping := range mappings { if len(mapping.to) == 0 { return true } } return false } func validateStructOrMap(t reflect.Type) bool { switch t.Kind() { case reflect.Map: return true case reflect.Ptr: t = t.Elem() fallthrough case reflect.Struct: return true default: return false } } func validateFieldMapping(predecessorType reflect.Type, successorType reflect.Type, mappings []*FieldMapping) ( // type checkers that are deferred to request-time typeHandler *handlerPair, // the remaining predecessor field paths that are not checked at compile time because of interface type found uncheckedSourcePath map[string]FieldPath, err error) { // check if mapping is legal if isFromAll(mappings) && isToAll(mappings) { // unreachable panic(fmt.Errorf("invalid field mappings: from all fields to all, use common edge instead")) } else if !isToAll(mappings) && (!validateStructOrMap(successorType) && successorType != reflect.TypeOf((*any)(nil)).Elem()) { // if user has not provided a specific struct type, graph cannot construct any struct in the runtime return nil, nil, fmt.Errorf("static check fail: successor input type should be struct or map, actual: %v", successorType) } else if fromFields(mappings) && !validateStructOrMap(predecessorType) { return nil, nil, fmt.Errorf("static check fail: predecessor output type should be struct or map, actual: %v", 
predecessorType) } var fieldCheckers map[string]handlerPair for i := range mappings { mapping := mappings[i] successorFieldType, successorRemaining, err := checkAndExtractFieldType(splitFieldPath(mapping.to), successorType) if err != nil { return nil, nil, fmt.Errorf("static check failed for mapping %s: %w", mapping, err) } if len(successorRemaining) > 0 { if successorFieldType == reflect.TypeOf((*any)(nil)).Elem() { continue // at request time expand this 'any' to 'map[string]any' } return nil, nil, fmt.Errorf("static check failed for mapping %s, the successor has intermediate interface type %v", mapping, successorFieldType) } if mapping.customExtractor != nil { // custom extractor applies to request-time data, so skip compile-time check continue } predecessorFieldType, predecessorRemaining, err := checkAndExtractFieldType(splitFieldPath(mapping.from), predecessorType) if err != nil { return nil, nil, fmt.Errorf("static check failed for mapping %s: %w", mapping, err) } if len(predecessorRemaining) > 0 { if uncheckedSourcePath == nil { uncheckedSourcePath = make(map[string]FieldPath) } uncheckedSourcePath[mapping.from] = predecessorRemaining } checker := func(a any) (any, error) { trueInType := reflect.TypeOf(a) if trueInType == nil { switch successorFieldType.Kind() { case reflect.Map, reflect.Slice, reflect.Ptr, reflect.Interface: default: return nil, fmt.Errorf("runtime check failed for mapping %s, field[%v]-[%v] is absolutely not assignable", mapping, trueInType, successorFieldType) } } else { if !trueInType.AssignableTo(successorFieldType) { return nil, fmt.Errorf("runtime check failed for mapping %s, field[%v]-[%v] is absolutely not assignable", mapping, trueInType, successorFieldType) } } return a, nil } if len(predecessorRemaining) > 0 { // can't check if types match at compile time, because there is interface type at some point along the source path. 
Defer to request time if fieldCheckers == nil { fieldCheckers = make(map[string]handlerPair) } fieldCheckers[mapping.to] = handlerPair{ invoke: checker, transform: func(input streamReader) streamReader { return packStreamReader(schema.StreamReaderWithConvert(input.toAnyStreamReader(), checker)) }, } } else { at := checkAssignable(predecessorFieldType, successorFieldType) if at == assignableTypeMustNot { return nil, nil, fmt.Errorf("static check failed for mapping %s, field[%v]-[%v] is absolutely not assignable", mapping, predecessorFieldType, successorFieldType) } else if at == assignableTypeMay { // can't decide if types match, because the successorFieldType implements predecessorFieldType, which is an interface type if fieldCheckers == nil { fieldCheckers = make(map[string]handlerPair) } fieldCheckers[mapping.to] = handlerPair{ invoke: checker, transform: func(input streamReader) streamReader { return packStreamReader(schema.StreamReaderWithConvert(input.toAnyStreamReader(), checker)) }, } } } } if len(fieldCheckers) == 0 { return nil, uncheckedSourcePath, nil } checker := func(value map[string]any) (map[string]any, error) { var err error for k, v := range fieldCheckers { for mapping := range value { if mapping == k { value[mapping], err = v.invoke(value[mapping]) if err != nil { return nil, err } } } } return value, nil } return &handlerPair{ invoke: func(value any) (any, error) { return checker(value.(map[string]any)) }, transform: func(input streamReader) streamReader { s, ok := unpackStreamReader[map[string]any](input) if !ok { // impossible panic("field mapping edge stream value isn't map[string]any") } return packStreamReader(schema.StreamReaderWithConvert(s, checker)) }, }, uncheckedSourcePath, nil } ================================================ FILE: compose/generic_graph.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this 
file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "reflect" "github.com/cloudwego/eino/internal/generic" ) type newGraphOptions struct { withState func(ctx context.Context) any stateType reflect.Type } // NewGraphOption configures behavior when creating a new graph, such as // providing local state generation. type NewGraphOption func(ngo *newGraphOptions) // WithGenLocalState registers a function to generate per-run local state // that can be shared across nodes in the graph. func WithGenLocalState[S any](gls GenLocalState[S]) NewGraphOption { return func(ngo *newGraphOptions) { ngo.withState = func(ctx context.Context) any { return gls(ctx) } ngo.stateType = generic.TypeOf[S]() } } // NewGraph create a directed graph that can compose components, lambda, chain, parallel etc. // simultaneously provide flexible and multi-granular aspect governance capabilities. 
// I: the input type of graph compiled product // O: the output type of graph compiled product // // To share state between nodes, use WithGenLocalState option: // // type testState struct { // UserInfo *UserInfo // KVs map[string]any // } // // genStateFunc := func(ctx context.Context) *testState { // return &testState{} // } // // graph := compose.NewGraph[string, string](WithGenLocalState(genStateFunc)) // // // you can use WithStatePreHandler and WithStatePostHandler to do something with state // graph.AddNode("node1", someNode, compose.WithPreHandler(func(ctx context.Context, in string, state *testState) (string, error) { // // do something with state // return in, nil // }), compose.WithPostHandler(func(ctx context.Context, out string, state *testState) (string, error) { // // do something with state // return out, nil // })) func NewGraph[I, O any](opts ...NewGraphOption) *Graph[I, O] { options := &newGraphOptions{} for _, opt := range opts { opt(options) } g := &Graph[I, O]{ newGraphFromGeneric[I, O]( ComponentOfGraph, options.withState, options.stateType, opts, ), } return g } // Graph is a generic graph that can be used to compose components. // I: the input type of graph compiled product // O: the output type of graph compiled product type Graph[I, O any] struct { *graph } // AddEdge adds an edge to the graph, edge means a data flow from startNode to endNode. // the previous node's output type must be set to the next node's input type. // NOTE: startNode and endNode must have been added to the graph before adding edge. // e.g. // // graph.AddNode("start_node_key", compose.NewPassthroughNode()) // graph.AddNode("end_node_key", compose.NewPassthroughNode()) // // err := graph.AddEdge("start_node_key", "end_node_key") func (g *Graph[I, O]) AddEdge(startNode, endNode string) (err error) { return g.graph.addEdgeWithMappings(startNode, endNode, false, false) } // Compile take the raw graph and compile it into a form ready to be run. // e.g. 
// // graph, err := compose.NewGraph[string, string]() // if err != nil {...} // // runnable, err := graph.Compile(ctx, compose.WithGraphName("my_graph")) // if err != nil {...} // // runnable.Invoke(ctx, "input") // invoke // runnable.Stream(ctx, "input") // stream // runnable.Collect(ctx, inputReader) // collect // runnable.Transform(ctx, inputReader) // transform func (g *Graph[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) { return compileAnyGraph[I, O](ctx, g, opts...) } func compileAnyGraph[I, O any](ctx context.Context, g AnyGraph, opts ...GraphCompileOption) (Runnable[I, O], error) { if len(globalGraphCompileCallbacks) > 0 { opts = append([]GraphCompileOption{WithGraphCompileCallbacks(globalGraphCompileCallbacks...)}, opts...) } option := newGraphCompileOptions(opts...) cr, err := g.compile(ctx, option) if err != nil { return nil, err } cr.meta = &executorMeta{ component: g.component(), isComponentCallbackEnabled: true, componentImplType: "", } cr.nodeInfo = &nodeInfo{ name: option.graphName, } ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { return initGraphCallbacks(AppendAddressSegment(ctx, AddressSegmentRunnable, option.graphName), cr.nodeInfo, cr.meta, opts...) } rp, err := toGenericRunnable[I, O](cr, ctxWrapper) if err != nil { return nil, err } return rp, nil } ================================================ FILE: compose/generic_helper.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "errors" "fmt" "reflect" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) func newGenericHelper[I, O any]() *genericHelper { return &genericHelper{ inputStreamFilter: defaultStreamMapFilter[I], outputStreamFilter: defaultStreamMapFilter[O], inputConverter: handlerPair{ invoke: defaultValueChecker[I], transform: defaultStreamConverter[I], }, outputConverter: handlerPair{ invoke: defaultValueChecker[O], transform: defaultStreamConverter[O], }, inputFieldMappingConverter: handlerPair{ invoke: buildFieldMappingConverter[I](), transform: buildStreamFieldMappingConverter[I](), }, outputFieldMappingConverter: handlerPair{ invoke: buildFieldMappingConverter[O](), transform: buildStreamFieldMappingConverter[O](), }, inputStreamConvertPair: defaultStreamConvertPair[I](), outputStreamConvertPair: defaultStreamConvertPair[O](), inputZeroValue: zeroValueFromGeneric[I], outputZeroValue: zeroValueFromGeneric[O], inputEmptyStream: emptyStreamFromGeneric[I], outputEmptyStream: emptyStreamFromGeneric[O], } } type genericHelper struct { // when set input key, use this method to convert input from map[string]any to T inputStreamFilter, outputStreamFilter streamMapFilter // when predecessor's output is assignableTypeMay to current node's input, validate and convert(if needed) types using the following two methods inputConverter, outputConverter handlerPair // when current node enable field mapping, convert map input to expected struct using the following two methods inputFieldMappingConverter, outputFieldMappingConverter handlerPair // can convert input/output from stream to non-stream or non-stream to stream, used for checkpoint inputStreamConvertPair, outputStreamConvertPair streamConvertPair inputZeroValue, outputZeroValue func() any inputEmptyStream, outputEmptyStream func() streamReader } func (g *genericHelper) 
forMapInput() *genericHelper { return &genericHelper{ outputStreamFilter: g.outputStreamFilter, outputConverter: g.outputConverter, outputFieldMappingConverter: g.outputFieldMappingConverter, outputStreamConvertPair: g.outputStreamConvertPair, outputZeroValue: g.outputZeroValue, outputEmptyStream: g.outputEmptyStream, inputStreamFilter: defaultStreamMapFilter[map[string]any], inputConverter: handlerPair{ invoke: defaultValueChecker[map[string]any], transform: defaultStreamConverter[map[string]any], }, inputFieldMappingConverter: handlerPair{ invoke: buildFieldMappingConverter[map[string]any](), transform: buildStreamFieldMappingConverter[map[string]any](), }, inputStreamConvertPair: defaultStreamConvertPair[map[string]any](), inputZeroValue: zeroValueFromGeneric[map[string]any], inputEmptyStream: emptyStreamFromGeneric[map[string]any], } } func (g *genericHelper) forMapOutput() *genericHelper { return &genericHelper{ inputStreamFilter: g.inputStreamFilter, inputConverter: g.inputConverter, inputFieldMappingConverter: g.inputFieldMappingConverter, inputStreamConvertPair: g.inputStreamConvertPair, inputZeroValue: g.inputZeroValue, inputEmptyStream: g.inputEmptyStream, outputStreamFilter: defaultStreamMapFilter[map[string]any], outputConverter: handlerPair{ invoke: defaultValueChecker[map[string]any], transform: defaultStreamConverter[map[string]any], }, outputFieldMappingConverter: handlerPair{ invoke: buildFieldMappingConverter[map[string]any](), transform: buildStreamFieldMappingConverter[map[string]any](), }, outputStreamConvertPair: defaultStreamConvertPair[map[string]any](), outputZeroValue: zeroValueFromGeneric[map[string]any], outputEmptyStream: emptyStreamFromGeneric[map[string]any], } } func (g *genericHelper) forPredecessorPassthrough() *genericHelper { return &genericHelper{ inputStreamFilter: g.inputStreamFilter, outputStreamFilter: g.inputStreamFilter, inputConverter: g.inputConverter, outputConverter: g.inputConverter, inputFieldMappingConverter: 
g.inputFieldMappingConverter, outputFieldMappingConverter: g.inputFieldMappingConverter, inputStreamConvertPair: g.inputStreamConvertPair, outputStreamConvertPair: g.inputStreamConvertPair, inputZeroValue: g.inputZeroValue, outputZeroValue: g.inputZeroValue, inputEmptyStream: g.inputEmptyStream, outputEmptyStream: g.inputEmptyStream, } } func (g *genericHelper) forSuccessorPassthrough() *genericHelper { return &genericHelper{ inputStreamFilter: g.outputStreamFilter, outputStreamFilter: g.outputStreamFilter, inputConverter: g.outputConverter, outputConverter: g.outputConverter, inputFieldMappingConverter: g.outputFieldMappingConverter, outputFieldMappingConverter: g.outputFieldMappingConverter, inputStreamConvertPair: g.outputStreamConvertPair, outputStreamConvertPair: g.outputStreamConvertPair, inputZeroValue: g.outputZeroValue, outputZeroValue: g.outputZeroValue, inputEmptyStream: g.outputEmptyStream, outputEmptyStream: g.outputEmptyStream, } } type streamMapFilter func(key string, isr streamReader) (streamReader, bool) type valueHandler func(value any) (any, error) type streamHandler func(streamReader) streamReader type handlerPair struct { invoke valueHandler transform streamHandler } type streamConvertPair struct { concatStream func(sr streamReader) (any, error) restoreStream func(any) (streamReader, error) } func defaultStreamConvertPair[T any]() streamConvertPair { var t T return streamConvertPair{ concatStream: func(sr streamReader) (any, error) { tsr, ok := unpackStreamReader[T](sr) if !ok { return nil, fmt.Errorf("cannot convert sr to streamReader[%T]", t) } value, err := concatStreamReader(tsr) if err != nil { if errors.Is(err, emptyStreamConcatErr) { return nil, nil } return nil, err } return value, nil }, restoreStream: func(a any) (streamReader, error) { if a == nil { return packStreamReader(schema.StreamReaderFromArray([]T{})), nil } value, ok := a.(T) if !ok { return nil, fmt.Errorf("cannot convert value[%T] to streamReader[%T]", a, t) } return 
packStreamReader(schema.StreamReaderFromArray([]T{value})), nil }, } } func defaultStreamMapFilter[T any](key string, isr streamReader) (streamReader, bool) { sr, ok := unpackStreamReader[map[string]any](isr) if !ok { return nil, false } cvt := func(m map[string]any) (T, error) { var t T v, ok_ := m[key] if !ok_ { return t, schema.ErrNoValue } vv, ok_ := v.(T) if !ok_ { return t, fmt.Errorf( "[defaultStreamMapFilter]fail, key[%s]'s value type[%s] isn't expected type[%s]", key, reflect.TypeOf(v).String(), generic.TypeOf[T]().String()) } return vv, nil } ret := schema.StreamReaderWithConvert[map[string]any, T](sr, cvt) return packStreamReader(ret), true } func defaultStreamConverter[T any](reader streamReader) streamReader { return packStreamReader(schema.StreamReaderWithConvert(reader.toAnyStreamReader(), func(v any) (T, error) { vv, ok := v.(T) if !ok { var t T return t, fmt.Errorf("runtime type check fail, expected type: %T, actual type: %T", t, v) } return vv, nil })) } func defaultValueChecker[T any](v any) (any, error) { nValue, ok := v.(T) if !ok { var t T return nil, fmt.Errorf("runtime type check fail, expected type: %T, actual type: %T", t, v) } return nValue, nil } func zeroValueFromGeneric[T any]() any { var t T return t } func emptyStreamFromGeneric[T any]() streamReader { var t T sr, sw := schema.Pipe[T](1) sw.Send(t, nil) sw.Close() return packStreamReader(sr) } ================================================ FILE: compose/graph.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package compose

import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"strings"

	"github.com/cloudwego/eino/components/document"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/indexer"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/internal/generic"
	"github.com/cloudwego/eino/internal/gmap"
)

// START is the start node of the graph. You can add your first edge with START.
const START = "start"

// END is the end node of the graph. You can add your last edge with END.
const END = "end"

// graphRunType is a custom type used to control the running mode of the graph.
type graphRunType string

const (
	// runTypePregel is a running mode suitable for large-scale graph processing tasks.
	// The graph may contain cycles. Compatible with NodeTriggerType.AnyPredecessor.
	runTypePregel graphRunType = "Pregel"
	// runTypeDAG is a running mode that treats the graph as a directed acyclic graph,
	// suitable for tasks that can be expressed as a DAG. Compatible with NodeTriggerType.AllPredecessor.
	runTypeDAG graphRunType = "DAG"
)

// String returns the string representation of the graph run type.
func (g graphRunType) String() string {
	return string(g)
}

// graph is the untyped core of a graph under construction. Typed wrappers such
// as Graph[I, O] embed it; cmp records which kind of component it backs (see
// isChain / isWorkflow). Maps below are keyed by node key.
type graph struct {
	// nodes holds every user-added node; START and END are rejected by addNode.
	nodes map[string]*graphNode
	// controlEdges maps a start node to the nodes that depend on it for execution.
	controlEdges map[string][]string
	// dataEdges maps a start node to the nodes that receive its output.
	dataEdges map[string][]string
	// branches maps a start node to the branches registered on it.
	branches map[string][]*GraphBranch
	// startNodes lists nodes reached directly from START; endNodes lists nodes leading to END.
	startNodes []string
	endNodes   []string

	// toValidateMap holds, per start node, data connections whose type
	// compatibility is still pending; updateToValidateMap removes an entry once
	// at least one side's type is known.
	toValidateMap map[string][]struct {
		endNode  string
		mappings []*FieldMapping
	}

	// stateType / stateGenerator come from WithGenLocalState; stateGenerator is
	// nil when graph state is not enabled.
	stateType      reflect.Type
	stateGenerator func(ctx context.Context) any
	newOpts        []NewGraphOption

	expectedInputType, expectedOutputType reflect.Type

	*genericHelper

	fieldMappingRecords map[string][]*FieldMapping

	// buildError latches the first error returned by a builder method; once
	// set, subsequent builder calls return it immediately.
	buildError error

	cmp component

	// compiled makes builder methods return ErrGraphCompiled once set.
	compiled bool

	// handler pairs (invoke + transform variants). handlerPreBranch is keyed by
	// start node and indexed by GraphBranch.idx; handlerOnEdges is presumably
	// keyed by start then end node, and handlerPreNode by node — populated
	// outside this view, TODO confirm.
	handlerOnEdges   map[string]map[string][]handlerPair
	handlerPreNode   map[string][]handlerPair
	handlerPreBranch map[string][][]handlerPair
}

// newGraphConfig carries everything newGraph needs to initialize a graph.
type newGraphConfig struct {
	inputType, outputType reflect.Type
	gh                    *genericHelper
	cmp                   component
	stateType             reflect.Type
	stateGenerator        func(ctx context.Context) any
	newOpts               []NewGraphOption
}

// newGraphFromGeneric creates a graph whose expected input/output types and
// generic helper are derived from the type parameters I and O.
func newGraphFromGeneric[I, O any](
	cmp component,
	stateGenerator func(ctx context.Context) any,
	stateType reflect.Type,
	opts []NewGraphOption,
) *graph {
	return newGraph(&newGraphConfig{
		inputType:      generic.TypeOf[I](),
		outputType:     generic.TypeOf[O](),
		gh:             newGenericHelper[I, O](),
		cmp:            cmp,
		stateType:      stateType,
		stateGenerator: stateGenerator,
		newOpts:        opts,
	})
}

// newGraph allocates an empty graph with all internal maps initialized.
func newGraph(cfg *newGraphConfig) *graph {
	return &graph{
		nodes:        make(map[string]*graphNode),
		dataEdges:    make(map[string][]string),
		controlEdges: make(map[string][]string),
		branches:     make(map[string][]*GraphBranch),
		toValidateMap: make(map[string][]struct {
			endNode  string
			mappings []*FieldMapping
		}),
		expectedInputType:   cfg.inputType,
		expectedOutputType:  cfg.outputType,
		genericHelper:       cfg.gh,
		fieldMappingRecords: make(map[string][]*FieldMapping),
		cmp:                 cfg.cmp,
		stateType:           cfg.stateType,
		stateGenerator:      cfg.stateGenerator,
		newOpts:             cfg.newOpts,
		handlerOnEdges:      make(map[string]map[string][]handlerPair),
		handlerPreNode:      make(map[string][]handlerPair),
		handlerPreBranch:    make(map[string][][]handlerPair),
	}
}

// component returns the kind of component this graph backs.
func (g *graph) component() component {
	return g.cmp
}

// isChain reports whether cmp identifies a chain.
func isChain(cmp component) bool {
	return cmp == ComponentOfChain
}

// isWorkflow reports whether cmp identifies a workflow.
func isWorkflow(cmp component) bool {
	return cmp == ComponentOfWorkflow
}

// ErrGraphCompiled is returned when attempting to modify a graph after it has been compiled.
var ErrGraphCompiled = errors.New("graph has been compiled, cannot be modified")

// addNode validates node and its options and registers it under key.
//
// Any error returned after the buildError/compiled guards is latched into
// g.buildError by the deferred func, so later builder calls fail fast with it.
// Validation rules:
//   - START and END are reserved keys, and keys must be unique.
//   - a node that needs state requires the graph to have a state generator.
//   - the node-key option is only supported when the graph backs a chain.
//   - a state pre-handler must use the graph's state type and produce exactly
//     the node's input type (or any, for a passthrough node whose input type
//     is nil); a state post-handler must likewise accept the node's output type.
func (g *graph) addNode(key string, node *graphNode, options *graphAddNodeOpts) (err error) {
	if g.buildError != nil {
		return g.buildError
	}

	if g.compiled {
		return ErrGraphCompiled
	}

	defer func() {
		if err != nil {
			g.buildError = err
		}
	}()

	if key == END || key == START {
		return fmt.Errorf("node '%s' is reserved, cannot add manually", key)
	}

	if _, ok := g.nodes[key]; ok {
		return fmt.Errorf("node '%s' already present", key)
	}

	// check options
	if options.needState {
		if g.stateGenerator == nil {
			return fmt.Errorf("node '%s' needs state but graph state is not enabled", key)
		}
	}

	if options.nodeOptions.nodeKey != "" {
		if !isChain(g.cmp) {
			return errors.New("only chain support node key option")
		}
	}
	// end: check options

	// check pre- / post-handler type
	if options.processor != nil {
		if options.processor.statePreHandler != nil {
			// check state type
			if g.stateType != options.processor.preStateType {
				return fmt.Errorf("node[%s]'s pre handler state type[%v] is different from graph[%v]", key, options.processor.preStateType, g.stateType)
			}
			// check input type: the pre handler's output feeds the node's input;
			// a passthrough node (nil input type) only accepts a handler on `any`
			if node.inputType() == nil && options.processor.statePreHandler.outputType != reflect.TypeOf((*any)(nil)).Elem() {
				return fmt.Errorf("passthrough node[%s]'s pre handler type isn't any", key)
			} else if node.inputType() != nil && node.inputType() != options.processor.statePreHandler.outputType {
				return fmt.Errorf("node[%s]'s pre handler type[%v] is different from its input type[%v]", key, options.processor.statePreHandler.outputType, node.inputType())
			}
		}
		if options.processor.statePostHandler != nil {
			// check state type
			if g.stateType != options.processor.postStateType {
				return fmt.Errorf("node[%s]'s post handler state type[%v] is different from graph[%v]", key, options.processor.postStateType, g.stateType)
			}
			// check output type: the node's output feeds the post handler's input;
			// a passthrough node (nil output type) only accepts a handler on `any`
			if node.outputType() == nil && options.processor.statePostHandler.inputType != reflect.TypeOf((*any)(nil)).Elem() {
				return fmt.Errorf("passthrough node[%s]'s post handler type isn't any", key)
			} else if node.outputType() != nil && node.outputType() != options.processor.statePostHandler.inputType {
				return fmt.Errorf("node[%s]'s post handler type[%v] is different from its output type[%v]", key, options.processor.statePostHandler.inputType, node.outputType())
			}
		}
	}

	g.nodes[key] = node

	return nil
}

// addEdgeWithMappings connects startNode to endNode.
//
// Unless noControl is set, a control (execution dependency) edge is added and
// START/END bookkeeping (startNodes / endNodes) is updated. Unless noData is
// set, a data edge carrying the optional field mappings is added and queued
// for type validation. Duplicate control or data edges are rejected.
//
// Errors raised after the deferred latch is installed are stored in
// g.buildError. The noControl && noData check runs BEFORE the latch, so that
// particular misuse is reported without poisoning the graph.
func (g *graph) addEdgeWithMappings(startNode, endNode string, noControl bool, noData bool, mappings ...*FieldMapping) (err error) {
	if g.buildError != nil {
		return g.buildError
	}
	if g.compiled {
		return ErrGraphCompiled
	}

	if noControl && noData {
		return fmt.Errorf("edge[%s]-[%s] cannot be both noDirectDependency and noDataFlow", startNode, endNode)
	}

	defer func() {
		if err != nil {
			g.buildError = err
		}
	}()

	if startNode == END {
		return errors.New("END cannot be a start node")
	}
	if endNode == START {
		return errors.New("START cannot be an end node")
	}

	if _, ok := g.nodes[startNode]; !ok && startNode != START {
		return fmt.Errorf("edge start node '%s' needs to be added to graph first", startNode)
	}
	if _, ok := g.nodes[endNode]; !ok && endNode != END {
		return fmt.Errorf("edge end node '%s' needs to be added to graph first", endNode)
	}

	if !noControl {
		for i := range g.controlEdges[startNode] {
			if g.controlEdges[startNode][i] == endNode {
				return fmt.Errorf("control edge[%s]-[%s] have been added yet", startNode, endNode)
			}
		}

		g.controlEdges[startNode] = append(g.controlEdges[startNode], endNode)
		if startNode == START {
			g.startNodes = append(g.startNodes, endNode)
		}
		if endNode == END {
			g.endNodes = append(g.endNodes, startNode)
		}
	}
	if !noData {
		for i := range g.dataEdges[startNode] {
			if g.dataEdges[startNode][i] == endNode {
				return fmt.Errorf("data edge[%s]-[%s] have been added yet", startNode, endNode)
			}
		}

		// queue the connection for type checking, then re-run validation so any
		// pending pairs whose types are now known get resolved
		g.addToValidateMap(startNode, endNode, mappings)
		err = g.updateToValidateMap()
		if err != nil {
			return err
		}
		g.dataEdges[startNode] = append(g.dataEdges[startNode], endNode)
	}

	return nil
}

// AddEmbeddingNode adds a node that implements embedding.Embedder.
// e.g.
//
//	embeddingNode, err := openai.NewEmbedder(ctx, &openai.EmbeddingConfig{
//		Model: "text-embedding-3-small",
//	})
//
//	graph.AddEmbeddingNode("embedding_node_key", embeddingNode)
func (g *graph) AddEmbeddingNode(key string, node embedding.Embedder, opts ...GraphAddNodeOpt) error {
	gNode, options := toEmbeddingNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddRetrieverNode adds a node that implements retriever.Retriever.
// e.g.
//
//	retrieverNode, err := vikingdb.NewRetriever(ctx, &vikingdb.RetrieverConfig{})
//
//	graph.AddRetrieverNode("retriever_node_key", retrieverNode)
func (g *graph) AddRetrieverNode(key string, node retriever.Retriever, opts ...GraphAddNodeOpt) error {
	gNode, options := toRetrieverNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddLoaderNode adds a node that implements document.Loader.
// e.g.
//
//	loader, err := file.NewLoader(ctx, &file.LoaderConfig{})
//
//	graph.AddLoaderNode("loader_node_key", loader)
func (g *graph) AddLoaderNode(key string, node document.Loader, opts ...GraphAddNodeOpt) error {
	gNode, options := toLoaderNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddIndexerNode adds a node that implements indexer.Indexer.
// e.g.
//
//	indexer, err := vikingdb.NewIndexer(ctx, &vikingdb.IndexerConfig{})
//
//	graph.AddIndexerNode("indexer_node_key", indexer)
func (g *graph) AddIndexerNode(key string, node indexer.Indexer, opts ...GraphAddNodeOpt) error {
	gNode, options := toIndexerNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddChatModelNode adds a node that implements model.BaseChatModel.
// e.g.
//
//	chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
//		Model: "gpt-4o",
//	})
//
//	graph.AddChatModelNode("chat_model_node_key", chatModel)
func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...GraphAddNodeOpt) error {
	gNode, options := toChatModelNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddChatTemplateNode adds a node that implements prompt.ChatTemplate.
// e.g.
//
//	chatTemplate, err := prompt.FromMessages(schema.FString, &schema.Message{
//		Role:    schema.System,
//		Content: "You are acting as a {role}.",
//	})
//
//	graph.AddChatTemplateNode("chat_template_node_key", chatTemplate)
func (g *graph) AddChatTemplateNode(key string, node prompt.ChatTemplate, opts ...GraphAddNodeOpt) error {
	gNode, options := toChatTemplateNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddToolsNode adds a *ToolsNode as a node.
// e.g.
//
//	toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{})
//
//	graph.AddToolsNode("tools_node_key", toolsNode)
func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) error {
	gNode, options := toToolsNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddDocumentTransformerNode adds a node that implements document.Transformer.
// e.g.
//
//	markdownSplitter, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderSplitterConfig{})
//
//	graph.AddDocumentTransformerNode("document_transformer_node_key", markdownSplitter)
func (g *graph) AddDocumentTransformerNode(key string, node document.Transformer, opts ...GraphAddNodeOpt) error {
	gNode, options := toDocumentTransformerNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddLambdaNode adds a node that implements at least one of Invoke[I, O], Stream[I, O], Collect[I, O], Transform[I, O].
// Because Go lacks method generics, function generics are used to build a Lambda that runs as Runnable[I, O]:
//   - for Invoke[I, O], use compose.InvokableLambda()
//   - for Stream[I, O], use compose.StreamableLambda()
//   - for Collect[I, O], use compose.CollectableLambda()
//   - for Transform[I, O], use compose.TransformableLambda()
//   - for arbitrary combinations of the 4 kinds of lambda, use compose.AnyLambda()
func (g *graph) AddLambdaNode(key string, node *Lambda, opts ...GraphAddNodeOpt) error {
	gNode, options := toLambdaNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddGraphNode adds a nested graph (any AnyGraph, e.g. Graph[I, O] or Chain[I, O]) as a node.
//   - Graph[I, O] comes from NewGraph[I, O]()
//   - Chain[I, O] comes from NewChain[I, O]()
func (g *graph) AddGraphNode(key string, node AnyGraph, opts ...GraphAddNodeOpt) error {
	gNode, options := toAnyGraphNode(node, opts...)
	return g.addNode(key, gNode, options)
}

// AddPassthroughNode adds a passthrough node to the graph.
// Mostly used in pregel mode of graph.
// e.g.
//
//	graph.AddPassthroughNode("passthrough_node_key")
func (g *graph) AddPassthroughNode(key string, opts ...GraphAddNodeOpt) error {
	gNode, options := toPassthroughNode(opts...)
	return g.addNode(key, gNode, options)
}

// AddBranch adds a branch to the graph.
// e.g.
// // condition := func(ctx context.Context, in string) (string, error) { // return "next_node_key", nil // } // endNodes := map[string]bool{"path01": true, "path02": true} // branch := compose.NewGraphBranch(condition, endNodes) // // graph.AddBranch("start_node_key", branch) func (g *graph) AddBranch(startNode string, branch *GraphBranch) (err error) { return g.addBranch(startNode, branch, false) } func (g *graph) addBranch(startNode string, branch *GraphBranch, skipData bool) (err error) { if g.buildError != nil { return g.buildError } if g.compiled { return ErrGraphCompiled } defer func() { if err != nil { g.buildError = err } }() if startNode == END { return errors.New("END cannot be a start node") } if _, ok := g.nodes[startNode]; !ok && startNode != START { return fmt.Errorf("branch start node '%s' needs to be added to graph first", startNode) } if _, ok := g.handlerPreBranch[startNode]; !ok { g.handlerPreBranch[startNode] = [][]handlerPair{} } branch.idx = len(g.handlerPreBranch[startNode]) if startNode != START && g.nodes[startNode].executorMeta.component == ComponentOfPassthrough { g.nodes[startNode].cr.inputType = branch.inputType g.nodes[startNode].cr.outputType = branch.inputType g.nodes[startNode].cr.genericHelper = branch.genericHelper.forPredecessorPassthrough() } // check branch condition type result := checkAssignable(g.getNodeOutputType(startNode), branch.inputType) if result == assignableTypeMustNot { return fmt.Errorf("condition's input type[%s] and start node[%s]'s output type[%s] are mismatched", branch.inputType.String(), startNode, g.getNodeOutputType(startNode).String()) } else if result == assignableTypeMay { g.handlerPreBranch[startNode] = append(g.handlerPreBranch[startNode], []handlerPair{branch.inputConverter}) } else { g.handlerPreBranch[startNode] = append(g.handlerPreBranch[startNode], []handlerPair{}) } if !skipData { for endNode := range branch.endNodes { if _, ok := g.nodes[endNode]; !ok { if endNode != END { return 
fmt.Errorf("branch end node '%s' needs to be added to graph first", endNode) } } g.addToValidateMap(startNode, endNode, nil) e := g.updateToValidateMap() if e != nil { return e } if startNode == START { g.startNodes = append(g.startNodes, endNode) } if endNode == END { g.endNodes = append(g.endNodes, startNode) } } } else { for endNode := range branch.endNodes { if startNode == START { g.startNodes = append(g.startNodes, endNode) } if endNode == END { g.endNodes = append(g.endNodes, startNode) } } branch.noDataFlow = true } g.branches[startNode] = append(g.branches[startNode], branch) return nil } func (g *graph) addToValidateMap(startNode, endNode string, mapping []*FieldMapping) { g.toValidateMap[startNode] = append(g.toValidateMap[startNode], struct { endNode string mappings []*FieldMapping }{endNode: endNode, mappings: mapping}) } // updateToValidateMap after update node, check validate map // check again if nodes in toValidateMap have been updated. because when there are multiple linked passthrough nodes, in the worst scenario, only one node can be updated at a time. func (g *graph) updateToValidateMap() error { var startNodeOutputType, endNodeInputType reflect.Type for { hasChanged := false for startNode := range g.toValidateMap { startNodeOutputType = g.getNodeOutputType(startNode) for i := 0; i < len(g.toValidateMap[startNode]); i++ { endNode := g.toValidateMap[startNode][i] endNodeInputType = g.getNodeInputType(endNode.endNode) if startNodeOutputType == nil && endNodeInputType == nil { continue } // update toValidateMap g.toValidateMap[startNode] = append(g.toValidateMap[startNode][:i], g.toValidateMap[startNode][i+1:]...) 
i-- hasChanged = true // assume that START and END type isn't empty if startNodeOutputType != nil && endNodeInputType == nil { g.nodes[endNode.endNode].cr.inputType = startNodeOutputType g.nodes[endNode.endNode].cr.outputType = g.nodes[endNode.endNode].cr.inputType g.nodes[endNode.endNode].cr.genericHelper = g.getNodeGenericHelper(startNode).forSuccessorPassthrough() } else if startNodeOutputType == nil /* redundant condition && endNodeInputType != nil */ { g.nodes[startNode].cr.inputType = endNodeInputType g.nodes[startNode].cr.outputType = g.nodes[startNode].cr.inputType g.nodes[startNode].cr.genericHelper = g.getNodeGenericHelper(endNode.endNode).forPredecessorPassthrough() } else if len(endNode.mappings) == 0 { // common node check result := checkAssignable(startNodeOutputType, endNodeInputType) if result == assignableTypeMustNot { return fmt.Errorf("graph edge[%s]-[%s]: start node's output type[%s] and end node's input type[%s] mismatch", startNode, endNode.endNode, startNodeOutputType.String(), endNodeInputType.String()) } else if result == assignableTypeMay { // add runtime check edges if _, ok := g.handlerOnEdges[startNode]; !ok { g.handlerOnEdges[startNode] = make(map[string][]handlerPair) } g.handlerOnEdges[startNode][endNode.endNode] = append(g.handlerOnEdges[startNode][endNode.endNode], g.getNodeGenericHelper(endNode.endNode).inputConverter) } continue } if len(endNode.mappings) > 0 { if _, ok := g.handlerOnEdges[startNode]; !ok { g.handlerOnEdges[startNode] = make(map[string][]handlerPair) } g.fieldMappingRecords[endNode.endNode] = append(g.fieldMappingRecords[endNode.endNode], endNode.mappings...) 
// field mapping check checker, uncheckedSourcePaths, err := validateFieldMapping(g.getNodeOutputType(startNode), g.getNodeInputType(endNode.endNode), endNode.mappings) if err != nil { return err } g.handlerOnEdges[startNode][endNode.endNode] = append(g.handlerOnEdges[startNode][endNode.endNode], handlerPair{ invoke: func(value any) (any, error) { return fieldMap(endNode.mappings, false, uncheckedSourcePaths)(value) }, transform: streamFieldMap(endNode.mappings, uncheckedSourcePaths), }) if checker != nil { g.handlerOnEdges[startNode][endNode.endNode] = append(g.handlerOnEdges[startNode][endNode.endNode], *checker) } } } } if !hasChanged { break } } return nil } func (g *graph) getNodeGenericHelper(name string) *genericHelper { if name == START { return g.genericHelper.forPredecessorPassthrough() } else if name == END { return g.genericHelper.forSuccessorPassthrough() } return g.nodes[name].getGenericHelper() } func (g *graph) getNodeInputType(name string) reflect.Type { if name == START { return g.inputType() } else if name == END { return g.outputType() } return g.nodes[name].inputType() } func (g *graph) getNodeOutputType(name string) reflect.Type { if name == START { return g.inputType() } else if name == END { return g.outputType() } return g.nodes[name].outputType() } func (g *graph) inputType() reflect.Type { return g.expectedInputType } func (g *graph) outputType() reflect.Type { return g.expectedOutputType } func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composableRunnable, error) { if g.buildError != nil { return nil, g.buildError } // get run type runType := runTypePregel cb := pregelChannelBuilder if isChain(g.cmp) || isWorkflow(g.cmp) { if opt != nil && opt.nodeTriggerMode != "" { return nil, errors.New(fmt.Sprintf("%s doesn't support node trigger mode option", g.cmp)) } } if (opt != nil && opt.nodeTriggerMode == AllPredecessor) || isWorkflow(g.cmp) { runType = runTypeDAG cb = dagChannelBuilder } // get eager type eager := 
false if isWorkflow(g.cmp) || runType == runTypeDAG { eager = true } if opt != nil && opt.eagerDisabled { eager = false } if len(g.startNodes) == 0 { return nil, errors.New("start node not set") } if len(g.endNodes) == 0 { return nil, errors.New("end node not set") } // toValidateMap isn't empty means there are nodes that cannot infer type for _, v := range g.toValidateMap { if len(v) > 0 { return nil, fmt.Errorf("some node's input or output types cannot be inferred: %v", g.toValidateMap) } } for key := range g.fieldMappingRecords { // not allowed to map multiple fields to the same field toMap := make(map[string]bool) for _, mapping := range g.fieldMappingRecords[key] { if _, ok := toMap[mapping.to]; ok { return nil, fmt.Errorf("duplicate mapping target field: %s of node[%s]", mapping.to, key) } toMap[mapping.to] = true } // add map to input converter g.handlerPreNode[key] = append(g.handlerPreNode[key], g.getNodeGenericHelper(key).inputFieldMappingConverter) } key2SubGraphs := g.beforeChildGraphsCompile(opt) chanSubscribeTo := make(map[string]*chanCall) for name, node := range g.nodes { node.beforeChildGraphCompile(name, key2SubGraphs) r, err := node.compileIfNeeded(ctx) if err != nil { return nil, err } chCall := &chanCall{ action: r, writeTo: g.dataEdges[name], controls: g.controlEdges[name], preProcessor: node.nodeInfo.preProcessor, postProcessor: node.nodeInfo.postProcessor, } branches := g.branches[name] if len(branches) > 0 { branchRuns := make([]*GraphBranch, 0, len(branches)) branchRuns = append(branchRuns, branches...) 
chCall.writeToBranches = branchRuns } chanSubscribeTo[name] = chCall } dataPredecessors := make(map[string][]string) controlPredecessors := make(map[string][]string) for start, ends := range g.controlEdges { for _, end := range ends { if _, ok := controlPredecessors[end]; !ok { controlPredecessors[end] = []string{start} } else { controlPredecessors[end] = append(controlPredecessors[end], start) } } } for start, ends := range g.dataEdges { for _, end := range ends { if _, ok := dataPredecessors[end]; !ok { dataPredecessors[end] = []string{start} } else { dataPredecessors[end] = append(dataPredecessors[end], start) } } } for start, branches := range g.branches { for _, branch := range branches { for end := range branch.endNodes { if _, ok := controlPredecessors[end]; !ok { controlPredecessors[end] = []string{start} } else { controlPredecessors[end] = append(controlPredecessors[end], start) } if !branch.noDataFlow { if _, ok := dataPredecessors[end]; !ok { dataPredecessors[end] = []string{start} } else { dataPredecessors[end] = append(dataPredecessors[end], start) } } } } } inputChannels := &chanCall{ writeTo: g.dataEdges[START], controls: g.controlEdges[START], writeToBranches: make([]*GraphBranch, len(g.branches[START])), } copy(inputChannels.writeToBranches, g.branches[START]) var mergeConfigs map[string]FanInMergeConfig if opt != nil { mergeConfigs = opt.mergeConfigs } if mergeConfigs == nil { mergeConfigs = make(map[string]FanInMergeConfig) } r := &runner{ chanSubscribeTo: chanSubscribeTo, controlPredecessors: controlPredecessors, dataPredecessors: dataPredecessors, inputChannels: inputChannels, eager: eager, chanBuilder: cb, inputType: g.inputType(), outputType: g.outputType(), genericHelper: g.genericHelper, preBranchHandlerManager: &preBranchHandlerManager{h: g.handlerPreBranch}, preNodeHandlerManager: &preNodeHandlerManager{h: g.handlerPreNode}, edgeHandlerManager: &edgeHandlerManager{h: g.handlerOnEdges}, mergeConfigs: mergeConfigs, } successors := 
make(map[string][]string) for ch := range r.chanSubscribeTo { successors[ch] = getSuccessors(r.chanSubscribeTo[ch]) } r.successors = successors if g.stateGenerator != nil { r.runCtx = func(ctx context.Context) context.Context { var parent *internalState if p, ok := ctx.Value(stateKey{}).(*internalState); ok { parent = p } return context.WithValue(ctx, stateKey{}, &internalState{ state: g.stateGenerator(ctx), parent: parent, }) } } if runType == runTypeDAG { err := validateDAG(r.chanSubscribeTo, controlPredecessors) if err != nil { return nil, err } r.dag = true } if opt != nil { inputPairs := make(map[string]streamConvertPair) outputPairs := make(map[string]streamConvertPair) for key, c := range r.chanSubscribeTo { inputPairs[key] = c.action.inputStreamConvertPair outputPairs[key] = c.action.outputStreamConvertPair } inputPairs[END] = r.outputConvertStreamPair outputPairs[START] = r.inputConvertStreamPair r.checkPointer = newCheckPointer(inputPairs, outputPairs, opt.checkPointStore, opt.serializer) r.interruptBeforeNodes = opt.interruptBeforeNodes r.interruptAfterNodes = opt.interruptAfterNodes r.options = *opt } // default options if r.dag && r.options.maxRunSteps > 0 { return nil, fmt.Errorf("cannot set max run steps in dag mode") } else if !r.dag && r.options.maxRunSteps == 0 { r.options.maxRunSteps = len(r.chanSubscribeTo) + 10 } g.compiled = true g.onCompileFinish(ctx, opt, key2SubGraphs) return r.toComposableRunnable(), nil } func getSuccessors(c *chanCall) []string { ret := make([]string, len(c.writeTo)) copy(ret, c.writeTo) ret = append(ret, c.controls...) 
for _, branch := range c.writeToBranches { for node := range branch.endNodes { ret = append(ret, node) } } return uniqueSlice(ret) } func uniqueSlice(s []string) []string { seen := make(map[string]struct{}, len(s)) cur := 0 for i := range s { if _, ok := seen[s[i]]; !ok { seen[s[i]] = struct{}{} s[cur] = s[i] cur++ } } return s[:cur] } type subGraphCompileCallback struct { closure func(ctx context.Context, info *GraphInfo) } // OnFinish is called when the graph is compiled. func (s *subGraphCompileCallback) OnFinish(ctx context.Context, info *GraphInfo) { s.closure(ctx, info) } func (g *graph) beforeChildGraphsCompile(opt *graphCompileOptions) map[string]*GraphInfo { if opt == nil || len(opt.callbacks) == 0 { return nil } return make(map[string]*GraphInfo) } func (gn *graphNode) beforeChildGraphCompile(nodeKey string, key2SubGraphs map[string]*GraphInfo) { if gn.g == nil || key2SubGraphs == nil { return } subGraphCallback := func(ctx2 context.Context, subGraph *GraphInfo) { key2SubGraphs[nodeKey] = subGraph } gn.nodeInfo.compileOption.callbacks = append(gn.nodeInfo.compileOption.callbacks, &subGraphCompileCallback{closure: subGraphCallback}) } func (g *graph) toGraphInfo(opt *graphCompileOptions, key2SubGraphs map[string]*GraphInfo) *GraphInfo { gInfo := &GraphInfo{ CompileOptions: opt.origOpts, Nodes: make(map[string]GraphNodeInfo, len(g.nodes)), Edges: gmap.Clone(g.controlEdges), DataEdges: gmap.Clone(g.dataEdges), Branches: gmap.Map(g.branches, func(startNode string, branches []*GraphBranch) (string, []GraphBranch) { branchInfo := make([]GraphBranch, 0, len(branches)) for _, b := range branches { branchInfo = append(branchInfo, GraphBranch{ invoke: b.invoke, collect: b.collect, inputType: b.inputType, genericHelper: b.genericHelper, endNodes: gmap.Clone(b.endNodes), }) } return startNode, branchInfo }), InputType: g.expectedInputType, OutputType: g.expectedOutputType, Name: opt.graphName, GenStateFn: g.stateGenerator, NewGraphOptions: g.newOpts, } for key := 
range g.nodes { gNode := g.nodes[key] if gNode.executorMeta.component == ComponentOfPassthrough { gInfo.Nodes[key] = GraphNodeInfo{ Component: gNode.executorMeta.component, GraphAddNodeOpts: gNode.opts, InputType: gNode.cr.inputType, OutputType: gNode.cr.outputType, Name: gNode.nodeInfo.name, InputKey: gNode.cr.nodeInfo.inputKey, OutputKey: gNode.cr.nodeInfo.outputKey, } continue } gNodeInfo := &GraphNodeInfo{ Component: gNode.executorMeta.component, Instance: gNode.instance, GraphAddNodeOpts: gNode.opts, InputType: gNode.cr.inputType, OutputType: gNode.cr.outputType, Name: gNode.nodeInfo.name, InputKey: gNode.cr.nodeInfo.inputKey, OutputKey: gNode.cr.nodeInfo.outputKey, Mappings: g.fieldMappingRecords[key], } if gi, ok := key2SubGraphs[key]; ok { gNodeInfo.GraphInfo = gi } gInfo.Nodes[key] = *gNodeInfo } return gInfo } func (g *graph) onCompileFinish(ctx context.Context, opt *graphCompileOptions, key2SubGraphs map[string]*GraphInfo) { if opt == nil { return } if len(opt.callbacks) == 0 { return } gInfo := g.toGraphInfo(opt, key2SubGraphs) for _, cb := range opt.callbacks { cb.OnFinish(ctx, gInfo) } } func (g *graph) getGenericHelper() *genericHelper { return g.genericHelper } func (g *graph) GetType() string { return "" } func transferTask(script [][]string, invertedEdges map[string][]string) [][]string { utilMap := map[string]bool{} for i := len(script) - 1; i >= 0; i-- { for j := 0; j < len(script[i]); j++ { // deduplicate if _, ok := utilMap[script[i][j]]; ok { script[i] = append(script[i][:j], script[i][j+1:]...) 
j-- continue } utilMap[script[i][j]] = true target := i for k := i + 1; k < len(script); k++ { hasDependencies := false for l := range script[k] { for _, dependency := range invertedEdges[script[i][j]] { if script[k][l] == dependency { hasDependencies = true break } } if hasDependencies { break } } if hasDependencies { break } target = k } if target != i { script[target] = append(script[target], script[i][j]) script[i] = append(script[i][:j], script[i][j+1:]...) j-- } } } return script } func validateDAG(chanSubscribeTo map[string]*chanCall, controlPredecessors map[string][]string) error { m := map[string]int{} for node := range chanSubscribeTo { if edges, ok := controlPredecessors[node]; ok { m[node] = len(edges) for _, pre := range edges { if pre == START { m[node] -= 1 } } } else { m[node] = 0 } } hasChanged := true for hasChanged { hasChanged = false for node := range m { if m[node] == 0 { hasChanged = true for _, subNode := range chanSubscribeTo[node].controls { if subNode == END { continue } m[subNode]-- } for _, subBranch := range chanSubscribeTo[node].writeToBranches { for subNode := range subBranch.endNodes { if subNode == END { continue } m[subNode]-- } } m[node] = -1 } } } var loopStarts []string for k, v := range m { if v > 0 { loopStarts = append(loopStarts, k) } } if len(loopStarts) > 0 { return fmt.Errorf("%w: %s", DAGInvalidLoopErr, formatLoops(findLoops(loopStarts, chanSubscribeTo))) } return nil } // DAGInvalidLoopErr indicates the graph contains a cycle and is invalid. var DAGInvalidLoopErr = errors.New("DAG is invalid, has loop") func findLoops(startNodes []string, chanCalls map[string]*chanCall) [][]string { controlSuccessors := map[string][]string{} for node, ch := range chanCalls { controlSuccessors[node] = append(controlSuccessors[node], ch.controls...) 
for _, b := range ch.writeToBranches { for end := range b.endNodes { controlSuccessors[node] = append(controlSuccessors[node], end) } } } visited := map[string]bool{} var dfs func(path []string) [][]string dfs = func(path []string) [][]string { var ret [][]string pathEnd := path[len(path)-1] successors, ok := controlSuccessors[pathEnd] if !ok { return nil } for _, successor := range successors { visited[successor] = true if successor == END { continue } var looped bool for i, node := range path { if node == successor { ret = append(ret, append(path[i:], successor)) looped = true break } } if looped { continue } ret = append(ret, dfs(append(path, successor))...) } return ret } var ret [][]string for _, node := range startNodes { if !visited[node] { ret = append(ret, dfs([]string{node})...) } } return ret } func formatLoops(loops [][]string) string { sb := strings.Builder{} for _, loop := range loops { if len(loop) == 0 { continue } sb.WriteString("[") sb.WriteString(loop[0]) for i := 1; i < len(loop); i++ { sb.WriteString("->") sb.WriteString(loop[i]) } sb.WriteString("]") } return sb.String() } // NewNodePath specifies a path to a node in the graph, which is composed of node keys. // Starting from the top graph, // following this set of node keys can lead to a specific node in the top graph or a subgraph. // // e.g. // NewNodePath("sub_graph_node_key", "node_key_within_sub_graph") func NewNodePath(nodeKeyPath ...string) *NodePath { return &NodePath{path: nodeKeyPath} } // NodePath represents a path composed of node keys to locate a node. type NodePath struct { path []string } // GetPath returns the sequence of node keys in the path. 
func (p *NodePath) GetPath() []string { return p.path } ================================================ FILE: compose/graph_add_node_options.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "reflect" "github.com/cloudwego/eino/internal/generic" ) type graphAddNodeOpts struct { nodeOptions *nodeOptions processor *processorOpts needState bool } // GraphAddNodeOpt is a functional option type for adding a node to a graph. // e.g. // // graph.AddNode("node_name", node, compose.WithInputKey("input_key"), compose.WithOutputKey("output_key")) type GraphAddNodeOpt func(o *graphAddNodeOpts) type nodeOptions struct { nodeName string nodeKey string inputKey string outputKey string graphCompileOption []GraphCompileOption // when this node is itself an AnyGraph, this option will be used to compile the node as a nested graph } // WithNodeName sets the name of the node. func WithNodeName(n string) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.nodeOptions.nodeName = n } } // WithNodeKey set the node key, which is used to identify the node in the chain. // only for use in Chain/StateChain. func WithNodeKey(key string) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.nodeOptions.nodeKey = key } } // WithInputKey sets the input key of the node. 
// this will change the input value of the node, for example, if the pre node's output is map[string]any{"key01": "value01"}, // and the current node's input key is "key01", then the current node's input value will be "value01". func WithInputKey(k string) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.nodeOptions.inputKey = k } } // WithOutputKey sets the output key of the node. // this will change the output value of the node, for example, if the current node's output key is "key01", // then the node's output value will be map[string]any{"key01": value}. func WithOutputKey(k string) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.nodeOptions.outputKey = k } } // WithGraphCompileOptions when the node is an AnyGraph, use this option to set compile option for the node. // e.g. // // graph.AddNode("node_name", node, compose.WithGraphCompileOptions(compose.WithGraphName("my_sub_graph"))) func WithGraphCompileOptions(opts ...GraphCompileOption) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.nodeOptions.graphCompileOption = opts } } // WithStatePreHandler modify node's input of I according to state S and input or store input information into state, and it's thread-safe. // notice: this option requires Graph to be created with WithGenLocalState option. // I: input type of the Node like ChatModel, Lambda, Retriever etc. // S: state type defined in WithGenLocalState func WithStatePreHandler[I, S any](pre StatePreHandler[I, S]) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.processor.statePreHandler = convertPreHandler(pre) o.processor.preStateType = generic.TypeOf[S]() o.needState = true } } // WithStatePostHandler modify node's output of O according to state S and output or store output information into state, and it's thread-safe. // notice: this option requires Graph to be created with WithGenLocalState option. // O: output type of the Node like ChatModel, Lambda, Retriever etc. 
// S: state type defined in WithGenLocalState func WithStatePostHandler[O, S any](post StatePostHandler[O, S]) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.processor.statePostHandler = convertPostHandler(post) o.processor.postStateType = generic.TypeOf[S]() o.needState = true } } // WithStreamStatePreHandler modify node's streaming input of I according to state S and input or store input information into state, and it's thread-safe. // notice: this option requires Graph to be created with WithGenLocalState option. // when to use: when upstream node's output is an actual stream, and you want the current node's input to remain an actual stream after state pre handler. // caution: while StreamStatePreHandler is thread safe, modifying state within your own goroutine is NOT. // I: input type of the Node like ChatModel, Lambda, Retriever etc. // S: state type defined in WithGenLocalState func WithStreamStatePreHandler[I, S any](pre StreamStatePreHandler[I, S]) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.processor.statePreHandler = streamConvertPreHandler(pre) o.processor.preStateType = generic.TypeOf[S]() o.needState = true } } // WithStreamStatePostHandler modify node's streaming output of O according to state S and output or store output information into state, and it's thread-safe. // notice: this option requires Graph to be created with WithGenLocalState option. // when to use: when current node's output is an actual stream, and you want the downstream node's input to remain an actual stream after state post handler. // caution: while StreamStatePostHandler is thread safe, modifying state within your own goroutine is NOT. // O: output type of the Node like ChatModel, Lambda, Retriever etc. 
// S: state type defined in WithGenLocalState func WithStreamStatePostHandler[O, S any](post StreamStatePostHandler[O, S]) GraphAddNodeOpt { return func(o *graphAddNodeOpts) { o.processor.statePostHandler = streamConvertPostHandler(post) o.processor.postStateType = generic.TypeOf[S]() o.needState = true } } type processorOpts struct { statePreHandler *composableRunnable preStateType reflect.Type // used for type validation statePostHandler *composableRunnable postStateType reflect.Type // used for type validation } func getGraphAddNodeOpts(opts ...GraphAddNodeOpt) *graphAddNodeOpts { opt := &graphAddNodeOpts{ nodeOptions: &nodeOptions{ nodeName: "", nodeKey: "", }, processor: &processorOpts{ statePreHandler: nil, statePostHandler: nil, }, } for _, fn := range opts { fn(opt) } return opt } ================================================ FILE: compose/graph_call_options.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "fmt" "reflect" "time" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" ) type graphCancelChanKey struct{} type graphCancelChanVal struct { ch chan *time.Duration } type graphInterruptOptions struct { timeout *time.Duration } // GraphInterruptOption configures behavior when interrupting a running graph. type GraphInterruptOption func(o *graphInterruptOptions) // WithGraphInterruptTimeout specifies the max waiting time before generating an interrupt. // After the max waiting time, the graph will force an interrupt. Any unfinished tasks will be re-run when the graph is resumed. func WithGraphInterruptTimeout(timeout time.Duration) GraphInterruptOption { return func(o *graphInterruptOptions) { o.timeout = &timeout } } // WithGraphInterrupt creates a context with graph cancellation support. // When the returned context is used to invoke a graph or workflow, calling the interrupt function will trigger an interrupt. // The graph will wait for current tasks to complete by default. // // Input Persistence: When WithGraphInterrupt is used, ALL nodes (in both root graph and subgraphs) will automatically // persist their inputs (both streaming and non-streaming) before execution. If the graph is interrupted, these inputs // are restored when the graph resumes from a checkpoint, ensuring interrupted nodes receive their original inputs. // // This behavior differs from internal interrupts triggered via compose.Interrupt() within a node's function body. 
// Internal interrupts do NOT automatically persist inputs - the node author must manage input persistence manually, // either by saving it in the global graph state or using compose.StatefulInterrupt() to store it in local interrupt state. // WithGraphInterrupt enables automatic input persistence because external interrupts can occur at any point during // node execution, making it impossible for the node to prepare for the interrupt. // // Why input persistence is not enabled by default for internal interrupts: Enabling it universally would break // existing code that relies on checking "input == nil" to determine whether the node is running for the first time // or resuming from an interrupt. The recommended approach is to use compose.GetInterruptState() to explicitly // determine whether the current execution is a first run or a resume. func WithGraphInterrupt(parent context.Context) (ctx context.Context, interrupt func(opts ...GraphInterruptOption)) { ch := make(chan *time.Duration, 1) ctx = context.WithValue(parent, graphCancelChanKey{}, &graphCancelChanVal{ ch: ch, }) return ctx, func(opts ...GraphInterruptOption) { o := &graphInterruptOptions{} for _, opt := range opts { opt(o) } ch <- o.timeout close(ch) } } func getGraphCancel(ctx context.Context) *graphCancelChanVal { val, ok := ctx.Value(graphCancelChanKey{}).(*graphCancelChanVal) if !ok { return nil } return val } // Option is a functional option type for calling a graph. 
type Option struct { options []any handler []callbacks.Handler paths []*NodePath maxRunSteps int checkPointID *string writeToCheckPointID *string forceNewRun bool stateModifier StateModifier } func (o Option) deepCopy() Option { nOptions := make([]any, len(o.options)) copy(nOptions, o.options) nHandler := make([]callbacks.Handler, len(o.handler)) copy(nHandler, o.handler) nPaths := make([]*NodePath, len(o.paths)) for i, path := range o.paths { nPath := *path nPaths[i] = &nPath } return Option{ options: nOptions, handler: nHandler, paths: nPaths, maxRunSteps: o.maxRunSteps, } } // DesignateNode sets the key of the node to which the option will be applied. // notice: only effective at the top graph. // e.g. // // embeddingOption := compose.WithEmbeddingOption(embedding.WithModel("text-embedding-3-small")) // runnable.Invoke(ctx, "input", embeddingOption.DesignateNode("embedding_node_key")) func (o Option) DesignateNode(nodeKey ...string) Option { nKeys := make([]*NodePath, len(nodeKey)) for i, k := range nodeKey { nKeys[i] = NewNodePath(k) } return o.DesignateNodeWithPath(nKeys...) } // DesignateNodeWithPath sets the path of the node(s) to which the option will be applied. // You can specify a node in the subgraph through `NodePath` to make the option only take effect at this node. // // e.g. // nodePath := NewNodePath("sub_graph_node_key", "node_key_within_sub_graph") // DesignateNodeWithPath(nodePath) func (o Option) DesignateNodeWithPath(path ...*NodePath) Option { o.paths = append(o.paths, path...) return o } // WithEmbeddingOption is a functional option type for embedding component. // e.g. // // embeddingOption := compose.WithEmbeddingOption(embedding.WithModel("text-embedding-3-small")) // runnable.Invoke(ctx, "input", embeddingOption) func WithEmbeddingOption(opts ...embedding.Option) Option { return withComponentOption(opts...) } // WithRetrieverOption is a functional option type for retriever component. // e.g. 
// // retrieverOption := compose.WithRetrieverOption(retriever.WithIndex("my_index")) // runnable.Invoke(ctx, "input", retrieverOption) func WithRetrieverOption(opts ...retriever.Option) Option { return withComponentOption(opts...) } // WithLoaderOption is a functional option type for loader component. // e.g. // // loaderOption := compose.WithLoaderOption(document.WithCollection("my_collection")) // runnable.Invoke(ctx, "input", loaderOption) func WithLoaderOption(opts ...document.LoaderOption) Option { return withComponentOption(opts...) } // WithDocumentTransformerOption is a functional option type for document transformer component. func WithDocumentTransformerOption(opts ...document.TransformerOption) Option { return withComponentOption(opts...) } // WithIndexerOption is a functional option type for indexer component. // e.g. // // indexerOption := compose.WithIndexerOption(indexer.WithSubIndexes([]string{"my_sub_index"})) // runnable.Invoke(ctx, "input", indexerOption) func WithIndexerOption(opts ...indexer.Option) Option { return withComponentOption(opts...) } // WithChatModelOption is a functional option type for chat model component. // e.g. // // chatModelOption := compose.WithChatModelOption(model.WithTemperature(0.7)) // runnable.Invoke(ctx, "input", chatModelOption) func WithChatModelOption(opts ...model.Option) Option { return withComponentOption(opts...) } // WithChatTemplateOption is a functional option type for chat template component. func WithChatTemplateOption(opts ...prompt.Option) Option { return withComponentOption(opts...) } // WithToolsNodeOption is a functional option type for tools node component. func WithToolsNodeOption(opts ...ToolsNodeOption) Option { return withComponentOption(opts...) } // WithLambdaOption is a functional option type for lambda component. 
func WithLambdaOption(opts ...any) Option { return Option{ options: opts, paths: make([]*NodePath, 0), } } // WithCallbacks set callback handlers for all components in a single call. // e.g. // // runnable.Invoke(ctx, "input", compose.WithCallbacks(&myCallbacks{})) func WithCallbacks(cbs ...callbacks.Handler) Option { return Option{ handler: cbs, } } // WithRuntimeMaxSteps sets the maximum number of steps for the graph runtime. // e.g. // // runnable.Invoke(ctx, "input", compose.WithRuntimeMaxSteps(20)) func WithRuntimeMaxSteps(maxSteps int) Option { return Option{ maxRunSteps: maxSteps, } } func withComponentOption[TOption any](opts ...TOption) Option { o := make([]any, 0, len(opts)) for i := range opts { o = append(o, opts[i]) } return Option{ options: o, paths: make([]*NodePath, 0), } } func convertOption[TOption any](opts ...any) ([]TOption, error) { if len(opts) == 0 { return nil, nil } ret := make([]TOption, 0, len(opts)) for i := range opts { o, ok := opts[i].(TOption) if !ok { return nil, fmt.Errorf("unexpected component option type, expected:%s, actual:%s", reflect.TypeOf((*TOption)(nil)).Elem().String(), reflect.TypeOf(opts[i]).String()) } ret = append(ret, o) } return ret, nil } ================================================ FILE: compose/graph_call_options_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/retriever" mockDocument "github.com/cloudwego/eino/internal/mock/components/document" mockEmbedding "github.com/cloudwego/eino/internal/mock/components/embedding" mockRetriever "github.com/cloudwego/eino/internal/mock/components/retriever" "github.com/cloudwego/eino/schema" ) var optionSuccess = true var idx int func checkOption(opts ...model.Option) bool { if len(opts) != 2 { return false } o := model.GetCommonOptions(&model.Options{}, opts...) if o.TopP == nil || *o.TopP != 1.0 { return false } if o.Model == nil { return false } if idx == 0 { idx = 1 if o.Model == nil || *o.Model != "123" { return false } } else { idx = 0 if o.Model == nil || *o.Model != "456" { return false } } return true } type testModel struct{} func (t *testModel) BindTools(tools []*schema.ToolInfo) error { return nil } func (t *testModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { if !checkOption(opts...) { optionSuccess = false } return &schema.Message{}, nil } func (t *testModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { if !checkOption(opts...) 
{ optionSuccess = false } sr, sw := schema.Pipe[*schema.Message](1) sw.Send(nil, nil) sw.Close() return sr, nil } func TestCallOption(t *testing.T) { g := NewGraph[[]*schema.Message, *schema.Message]() err := g.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input []*schema.Message, opts ...string) (output []*schema.Message, err error) { if len(opts) != 1 || opts[0] != "1" { t.Fatalf("lambda option length isn't 1 or content isn't '1': %v", opts) } return input, nil })) assert.Nil(t, err) err = g.AddChatModelNode("2", &testModel{}) assert.Nil(t, err) err = g.AddLambdaNode("-", InvokableLambda(func(ctx context.Context, input *schema.Message) (output []*schema.Message, err error) { return []*schema.Message{input}, nil })) assert.Nil(t, err) err = g.AddChatModelNode("3", &testModel{}) if err != nil { t.Fatal(err) } err = g.AddEdge(START, "1") if err != nil { t.Fatal(err) } err = g.AddEdge("1", "2") assert.Nil(t, err) err = g.AddEdge("2", "-") assert.Nil(t, err) err = g.AddEdge("-", "3") assert.Nil(t, err) err = g.AddEdge("3", END) assert.Nil(t, err) ctx := context.Background() r, err := g.Compile(ctx) assert.Nil(t, err) sessionKey := struct{}{} startCnt := 0 endCnt := 0 opts := []Option{ WithChatModelOption( model.WithModel("123"), ).DesignateNode("2"), WithChatModelOption( model.WithModel("456"), ).DesignateNode("3"), WithChatModelOption( model.WithTopP(1.0), ), WithCallbacks(callbacks.NewHandlerBuilder(). OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { startCnt++ return context.WithValue(ctx, sessionKey, "start") }). OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { if ctx.Value(sessionKey).(string) == "start" { endCnt++ return context.WithValue(ctx, sessionKey, "end") } return ctx }).Build()).DesignateNode("3"), WithLambdaOption("1").DesignateNode("1"), } _, err = r.Invoke(ctx, []*schema.Message{}, opts...) 
if err != nil { t.Fatal(err) } if !optionSuccess { t.Fatal("invoke option fail") } if startCnt != 1 { t.Fatal("node callback fail") } if endCnt != 1 { t.Fatal("node callback fail") } _, err = r.Stream(ctx, []*schema.Message{}, opts...) if err != nil { t.Fatal(err) } if !optionSuccess { t.Fatal("stream option fail") } srOfCollect, swOfCollect := schema.Pipe[[]*schema.Message](1) swOfCollect.Send([]*schema.Message{}, nil) swOfCollect.Close() _, err = r.Collect(ctx, srOfCollect, opts...) assert.Nil(t, err) if !optionSuccess { t.Fatal("collect option fail") } srOfTransform, swOfTransform := schema.Pipe[[]*schema.Message](1) swOfTransform.Send([]*schema.Message{}, nil) swOfTransform.Close() _, err = r.Transform(ctx, srOfTransform, opts...) assert.Nil(t, err) if !optionSuccess { t.Fatal("transform option fail") } } func TestCallOptionsOneByOne(t *testing.T) { ctx := context.Background() t.Run("common_option", func(t *testing.T) { type option struct { uid int64 } opt := withComponentOption(&option{uid: 100}) assert.Len(t, opt.options, 1) assert.IsType(t, &option{}, opt.options[0]) assert.Equal(t, &option{uid: 100}, opt.options[0]) }) t.Run("embedding_option", func(t *testing.T) { ctrl := gomock.NewController(t) inst := mockEmbedding.NewMockEmbedder(ctrl) var opt *embedding.Options inst.EXPECT().EmbedStrings(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { opt = embedding.GetCommonOptions(&embedding.Options{}, opts...) 
return nil, nil }).Times(1) ch := NewChain[map[string]any, map[string]any]() ch.AppendEmbedding(inst, WithInputKey("input"), WithOutputKey("output")) r, err := ch.Compile(ctx) assert.NoError(t, err) outs, err := r.Invoke(ctx, map[string]any{"input": []string{}}, WithEmbeddingOption(embedding.WithModel("123")), ) assert.NoError(t, err) assert.Contains(t, outs, "output") assert.NotNil(t, opt.Model) assert.Equal(t, "123", *opt.Model) }) t.Run("retriever_option", func(t *testing.T) { ctrl := gomock.NewController(t) inst := mockRetriever.NewMockRetriever(ctrl) var opt *retriever.Options inst.EXPECT().Retrieve(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { opt = retriever.GetCommonOptions(&retriever.Options{}, opts...) return nil, nil }). Times(1) ch := NewChain[map[string]any, map[string]any]() ch.AppendRetriever(inst, WithInputKey("input"), WithOutputKey("output")) r, err := ch.Compile(ctx) assert.NoError(t, err) outs, err := r.Invoke(ctx, map[string]any{"input": "hi"}, WithRetrieverOption(retriever.WithIndex("123")), ) assert.NoError(t, err) assert.Contains(t, outs, "output") assert.NotNil(t, opt.Index) assert.Equal(t, "123", *opt.Index) }) t.Run("loader_option", func(t *testing.T) { ctrl := gomock.NewController(t) inst := mockDocument.NewMockLoader(ctrl) type implOption struct { uid int64 } type implOptFn func(o *implOption) withUID := func(uid int64) document.LoaderOption { return document.WrapLoaderImplSpecificOptFn[implOption](func(i *implOption) { i.uid = uid }) } var opt *implOption inst.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, src document.Source, opts ...document.LoaderOption) ([]*schema.Document, error) { opt = document.GetLoaderImplSpecificOptions[implOption](&implOption{uid: 1}, opts...) return nil, nil }). 
Times(1) ch := NewChain[map[string]any, map[string]any]() ch.AppendLoader(inst, WithInputKey("input"), WithOutputKey("output")) r, err := ch.Compile(ctx) assert.NoError(t, err) outs, err := r.Invoke(ctx, map[string]any{"input": document.Source{}}, WithLoaderOption(withUID(123)), ) assert.NoError(t, err) assert.Contains(t, outs, "output") assert.Equal(t, int64(123), opt.uid) }) } func TestCallOptionInSubGraph(t *testing.T) { ctx := context.Background() type child1Option string type child2Option string type parentOption string type grandparentOption string child1 := NewGraph[string, string]() err := child1.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input string, opts ...child1Option) (output string, err error) { if len(opts) != 1 || opts[0] != "child1-1" { t.Fatal("child1-1 option error") } return input + " child1-1", nil }), WithNodeName("child1-1")) assert.NoError(t, err) err = child1.AddEdge(START, "1") assert.NoError(t, err) err = child1.AddEdge("1", END) assert.NoError(t, err) child2 := NewGraph[string, string]() err = child2.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input string, opts ...child2Option) (output string, err error) { if len(opts) != 1 || opts[0] != "child2-1" { t.Fatal("child2-1 option error") } return input + " child2-1", nil }), WithNodeName("child2-1")) assert.NoError(t, err) err = child2.AddEdge(START, "1") assert.NoError(t, err) err = child2.AddEdge("1", END) assert.NoError(t, err) parent := NewGraph[string, string]() err = parent.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input string, opts ...parentOption) (output string, err error) { if len(opts) != 1 || opts[0] != "parent-1" { t.Fatal("parent-1 option error") } return input + " parent-1", nil }), WithNodeName("parent-1")) assert.NoError(t, err) err = parent.AddGraphNode("2", child1, WithNodeName("child1")) assert.NoError(t, err) err = parent.AddGraphNode("3", child2, WithNodeName("child2")) 
assert.NoError(t, err) err = parent.AddEdge(START, "1") assert.NoError(t, err) err = parent.AddEdge("1", "2") assert.NoError(t, err) err = parent.AddEdge("2", "3") assert.NoError(t, err) err = parent.AddEdge("3", END) assert.NoError(t, err) grandParent := NewGraph[string, string]() err = grandParent.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input string, opts ...grandparentOption) (output string, err error) { if len(opts) != 1 || opts[0] != "grandparent-1" { t.Fatal("grandparent-1 option error") } return input + " grandparent-1", nil }), WithNodeName("grandparent-1")) assert.NoError(t, err) err = grandParent.AddGraphNode("2", parent, WithNodeName("parent")) assert.NoError(t, err) err = grandParent.AddEdge(START, "1") assert.NoError(t, err) err = grandParent.AddEdge("1", "2") assert.NoError(t, err) err = grandParent.AddEdge("2", END) assert.NoError(t, err) r, err := grandParent.Compile(ctx, WithGraphName("grandparent")) assert.NoError(t, err) grandCommonTimes := 0 grandCommonCB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { switch grandCommonTimes { case 0: if info.Name != "grandparent" || info.Component != ComponentOfGraph { t.Fatal("grandparent common callback 0 error") } case 1: if info.Name != "grandparent-1" { t.Fatal("grandparent common callback 1 error") } case 2: if info.Name != "parent" { t.Fatal("grandparent common callback 2 error") } case 3: if info.Name != "parent-1" { t.Fatal("grandparent common callback 3 error") } case 4: if info.Name != "child1" { t.Fatal("grandparent common callback 4 error") } case 5: if info.Name != "child1-1" { t.Fatal("grandparent common callback 5 error") } case 6: if info.Name != "child2" { t.Fatal("grandparent common callback 6 error") } case 7: if info.Name != "child2-1" { t.Fatal("grandparent common callback 7 error") } default: t.Fatal("grandparent common callback too many") } grandCommonTimes++ 
return ctx }).Build() grand1CB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Name != "grandparent-1" { t.Fatal("grandparent common callback 0 error") } return ctx }).Build() parentCommonCBTimes := 0 parentCommonCB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { switch parentCommonCBTimes { case 0: if info.Name != "parent" { t.Fatal("parent common callback 0 error") } case 1: if info.Name != "parent-1" { t.Fatal("parent common callback 1 error") } case 2: if info.Name != "child1" { t.Fatal("parent common callback 2 error") } case 3: if info.Name != "child1-1" { t.Fatal("parent common callback 3 error") } case 4: if info.Name != "child2" { t.Fatal("parent common callback 4 error") } case 5: if info.Name != "child2-1" { t.Fatal("parent common callback 5 error") } default: t.Fatal("parent common callback too many") } parentCommonCBTimes++ return ctx }).Build() child1CommonCBTimes := 0 child1CommonCB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { switch child1CommonCBTimes { case 0: if info.Name != "child1" { t.Fatal("child1 common callback 0 error") } case 1: if info.Name != "child1-1" { t.Fatal("child1 common callback 1 error") } default: t.Fatal("child1 common callback too many") } child1CommonCBTimes++ return ctx }).Build() child2CB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Name != "child2-1" { t.Fatal("child2-1 common callback 0 error") } return ctx }).Build() result, err := r.Invoke(ctx, "input", WithCallbacks(grandCommonCB), WithCallbacks(parentCommonCB).DesignateNodeWithPath(NewNodePath("2")), WithCallbacks(grand1CB).DesignateNode("1"), 
WithCallbacks(child1CommonCB).DesignateNodeWithPath(NewNodePath("2", "2")), WithCallbacks(child2CB).DesignateNodeWithPath(NewNodePath("2", "3", "1")), WithLambdaOption(grandparentOption("grandparent-1")).DesignateNodeWithPath(NewNodePath("1")), WithLambdaOption(parentOption("parent-1")).DesignateNodeWithPath(NewNodePath("2", "1")), WithLambdaOption(child1Option("child1-1")).DesignateNodeWithPath(NewNodePath("2", "2", "1")), WithLambdaOption(child2Option("child2-1")).DesignateNodeWithPath(NewNodePath("2", "3", "1")), ) assert.NoError(t, err) assert.Equal(t, result, "input grandparent-1 parent-1 child1-1 child2-1") } ================================================ FILE: compose/graph_compile_options.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose type graphCompileOptions struct { maxRunSteps int graphName string nodeTriggerMode NodeTriggerMode // default to AnyPredecessor (pregel) callbacks []GraphCompileCallback origOpts []GraphCompileOption checkPointStore CheckPointStore serializer Serializer interruptBeforeNodes []string interruptAfterNodes []string eagerDisabled bool mergeConfigs map[string]FanInMergeConfig } func newGraphCompileOptions(opts ...GraphCompileOption) *graphCompileOptions { option := &graphCompileOptions{} for _, o := range opts { o(option) } option.origOpts = opts return option } // GraphCompileOption options for compiling AnyGraph. 
type GraphCompileOption func(*graphCompileOptions)

// WithMaxRunSteps sets the maximum number of steps that a graph can run.
// This is useful to prevent infinite loops in graphs with cycles.
// If the number of steps exceeds maxSteps, the graph execution will be terminated with an error.
func WithMaxRunSteps(maxSteps int) GraphCompileOption {
	return func(o *graphCompileOptions) {
		o.maxRunSteps = maxSteps
	}
}

// WithGraphName sets a name for the graph.
// The name is used for debugging and logging purposes.
// If not set, a default name will be used.
func WithGraphName(graphName string) GraphCompileOption {
	return func(o *graphCompileOptions) {
		o.graphName = graphName
	}
}

// WithEagerExecution enables the eager execution mode for the graph.
// In eager mode, nodes will be executed immediately once they are ready to run,
// without waiting for the completion of a super step, ref: https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/orchestration_design_principles/#runtime-engine
// Note: Eager mode is not allowed when the graph's trigger mode is set to AnyPredecessor.
// Workflow uses eager mode by default.
//
// Deprecated: Eager execution is automatically enabled by default when a node's trigger mode is set to AllPredecessor.
// If you were using this option previously, it can be safely removed without changing behavior.
func WithEagerExecution() GraphCompileOption {
	// Intentionally a no-op: the returned option leaves graphCompileOptions untouched.
	return func(*graphCompileOptions) {}
}

// WithEagerExecutionDisabled disables the eager execution mode for the graph.
// By default, eager execution is enabled for Workflow and Graph with the AllPredecessor trigger mode.
// After using this option, nodes will wait for the completion of a super step instead of executing immediately once they are ready to run.
// ref: https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/orchestration_design_principles/#runtime-engine func WithEagerExecutionDisabled() GraphCompileOption { return func(o *graphCompileOptions) { o.eagerDisabled = true } } // WithNodeTriggerMode sets the trigger mode for nodes in the graph. // The trigger mode determines when a node is triggered during graph execution, ref: https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/orchestration_design_principles/#runtime-engine // AnyPredecessor by default. func WithNodeTriggerMode(triggerMode NodeTriggerMode) GraphCompileOption { return func(o *graphCompileOptions) { o.nodeTriggerMode = triggerMode } } // WithGraphCompileCallbacks sets callbacks for graph compilation. func WithGraphCompileCallbacks(cbs ...GraphCompileCallback) GraphCompileOption { return func(o *graphCompileOptions) { o.callbacks = append(o.callbacks, cbs...) } } // FanInMergeConfig defines the configuration for fan-in merge operations. // It allows specifying how multiple inputs are merged into a single input. // StreamMergeWithSourceEOF indicates whether to emit a SourceEOF error for each stream // when it ends, before the final merged output is produced. This is useful for // tracking the completion of individual input streams in a named stream merge. type FanInMergeConfig struct { StreamMergeWithSourceEOF bool //indicates whether to emit a SourceEOF error for each stream } // WithFanInMergeConfig sets the fan-in merge configurations // for the graph nodes that receive inputs from multiple sources. 
func WithFanInMergeConfig(confs map[string]FanInMergeConfig) GraphCompileOption { return func(o *graphCompileOptions) { o.mergeConfigs = confs } } // InitGraphCompileCallbacks set global graph compile callbacks, // which ONLY will be added to top level graph compile options func InitGraphCompileCallbacks(cbs []GraphCompileCallback) { globalGraphCompileCallbacks = cbs } var globalGraphCompileCallbacks []GraphCompileCallback ================================================ FILE: compose/graph_manager.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "fmt" "runtime/debug" "time" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/safe" ) type channel interface { reportValues(map[string]any) error reportDependencies([]string) reportSkip([]string) bool get(bool, string, *edgeHandlerManager) (any, bool, error) convertValues(fn func(map[string]any) error) error load(channel) error setMergeConfig(FanInMergeConfig) } type edgeHandlerManager struct { h map[string]map[string][]handlerPair } func (e *edgeHandlerManager) handle(from, to string, value any, isStream bool) (any, error) { if _, ok := e.h[from]; !ok { return value, nil } if _, ok := e.h[from][to]; !ok { return value, nil } if isStream { for _, v := range e.h[from][to] { value = v.transform(value.(streamReader)) } } else { for _, v := range e.h[from][to] { var err error value, err = v.invoke(value) if err != nil { return nil, err } } } return value, nil } type preNodeHandlerManager struct { h map[string][]handlerPair } func (p *preNodeHandlerManager) handle(nodeKey string, value any, isStream bool) (any, error) { if _, ok := p.h[nodeKey]; !ok { return value, nil } if isStream { for _, v := range p.h[nodeKey] { value = v.transform(value.(streamReader)) } } else { for _, v := range p.h[nodeKey] { var err error value, err = v.invoke(value) if err != nil { return nil, err } } } return value, nil } type preBranchHandlerManager struct { h map[string][][]handlerPair } func (p *preBranchHandlerManager) handle(nodeKey string, idx int, value any, isStream bool) (any, error) { if _, ok := p.h[nodeKey]; !ok { return value, nil } if isStream { for _, v := range p.h[nodeKey][idx] { value = v.transform(value.(streamReader)) } } else { for _, v := range p.h[nodeKey][idx] { var err error value, err = v.invoke(value) if err != nil { return nil, err } } } return value, nil } type channelManager struct { isStream bool channels map[string]channel successors map[string][]string dataPredecessors 
map[string]map[string]struct{} controlPredecessors map[string]map[string]struct{} edgeHandlerManager *edgeHandlerManager preNodeHandlerManager *preNodeHandlerManager } func (c *channelManager) loadChannels(channels map[string]channel) error { for key, ch := range c.channels { if nCh, ok := channels[key]; ok { if err := ch.load(nCh); err != nil { return fmt.Errorf("load channel[%s] fail: %w", key, err) } } } return nil } func (c *channelManager) updateValues(_ context.Context, values map[string] /*to*/ map[string] /*from*/ any) error { for target, fromMap := range values { toChannel, ok := c.channels[target] if !ok { return fmt.Errorf("target channel doesn't existed: %s", target) } dps, ok := c.dataPredecessors[target] if !ok { dps = map[string]struct{}{} } nFromMap := make(map[string]any, len(fromMap)) for from, value := range fromMap { if _, ok = dps[from]; ok { nFromMap[from] = fromMap[from] } else { if sr, okk := value.(streamReader); okk { sr.close() } } } err := toChannel.reportValues(nFromMap) if err != nil { return fmt.Errorf("update target channel[%s] fail: %w", target, err) } } return nil } func (c *channelManager) updateDependencies(_ context.Context, dependenciesMap map[string][]string) error { for target, dependencies := range dependenciesMap { toChannel, ok := c.channels[target] if !ok { return fmt.Errorf("target channel doesn't existed: %s", target) } cps, ok := c.controlPredecessors[target] if !ok { cps = map[string]struct{}{} } var deps []string for _, from := range dependencies { if _, ok = cps[from]; ok { deps = append(deps, from) } } toChannel.reportDependencies(deps) } return nil } func (c *channelManager) getFromReadyChannels(_ context.Context) (map[string]any, error) { result := make(map[string]any) for target, ch := range c.channels { v, ready, err := ch.get(c.isStream, target, c.edgeHandlerManager) if err != nil { return nil, fmt.Errorf("get value from ready channel[%s] fail: %w", target, err) } if ready { v, err = 
c.preNodeHandlerManager.handle(target, v, c.isStream) if err != nil { return nil, err } result[target] = v } } return result, nil } func (c *channelManager) updateAndGet(ctx context.Context, values map[string]map[string]any, dependencies map[string][]string) (map[string]any, error) { err := c.updateValues(ctx, values) if err != nil { return nil, fmt.Errorf("update channel fail: %w", err) } err = c.updateDependencies(ctx, dependencies) if err != nil { return nil, fmt.Errorf("update channel fail: %w", err) } return c.getFromReadyChannels(ctx) } func (c *channelManager) reportBranch(from string, skippedNodes []string) error { var nKeys []string for _, node := range skippedNodes { skipped := c.channels[node].reportSkip([]string{from}) if skipped { nKeys = append(nKeys, node) } } for i := 0; i < len(nKeys); i++ { key := nKeys[i] if key == END { continue } if _, ok := c.successors[key]; !ok { return fmt.Errorf("unknown node: %s", key) } for _, successor := range c.successors[key] { skipped := c.channels[successor].reportSkip([]string{key}) if skipped { nKeys = appendIfNotExist(nKeys, successor) } // todo: detect if end node has been skipped? 
} } return nil } func appendIfNotExist(s []string, elem string) []string { for _, i := range s { if i == elem { return s } } return append(s, elem) } type task struct { ctx context.Context nodeKey string call *chanCall input any originalInput any output any option []any err error skipPreHandler bool } type taskManager struct { runWrapper runnableCallWrapper opts []Option needAll bool num uint32 done *internal.UnboundedChan[*task] runningTasks map[string]*task cancelCh chan *time.Duration canceled bool deadline *time.Time persistRerunInput bool } func (t *taskManager) execute(currentTask *task) { defer func() { panicInfo := recover() if panicInfo != nil { currentTask.output = nil currentTask.err = safe.NewPanicErr(panicInfo, debug.Stack()) } t.done.Send(currentTask) }() ctx := initNodeCallbacks(currentTask.ctx, currentTask.nodeKey, currentTask.call.action.nodeInfo, currentTask.call.action.meta, t.opts...) currentTask.output, currentTask.err = t.runWrapper(ctx, currentTask.call.action, currentTask.input, currentTask.option...) } func (t *taskManager) submit(tasks []*task) error { if len(tasks) == 0 { return nil } // synchronously execute one task, if there are no other tasks in the task pool and meet one of the following conditions: // 1. the new task is the only one // 2. the task manager mode is set to needAll for i := 0; i < len(tasks); i++ { currentTask := tasks[i] if t.persistRerunInput { if sr, ok := currentTask.input.(streamReader); ok { copies := sr.copy(2) currentTask.originalInput, currentTask.input = copies[0], copies[1] } else { currentTask.originalInput = currentTask.input } } err := runPreHandler(currentTask, t.runWrapper) if err != nil { // pre-handler error, regarded as a failure of the task itself currentTask.err = err tasks = append(tasks[:i], tasks[i+1:]...) 
i-- t.num++ t.done.Send(currentTask) } t.runningTasks[currentTask.nodeKey] = currentTask } if len(tasks) == 0 { // all tasks' pre-handler failed return nil } var syncTask *task if t.num == 0 && (len(tasks) == 1 || t.needAll) && t.cancelCh == nil /*if graph can be interrupted by user, shouldn't sync run task*/ { syncTask = tasks[0] tasks = tasks[1:] } for _, currentTask := range tasks { t.num += 1 go t.execute(currentTask) } if syncTask != nil { t.num += 1 t.execute(syncTask) } return nil } func (t *taskManager) wait() (tasks []*task, canceled bool, canceledTasks []*task) { if t.needAll { tasks, canceledTasks = t.waitAll() return tasks, t.canceled, canceledTasks } ta, success, canceled := t.waitOne() if canceled { // has canceled and timeout, return canceled tasks for _, rta := range t.runningTasks { canceledTasks = append(canceledTasks, rta) } t.runningTasks = make(map[string]*task) t.num = 0 return nil, true, canceledTasks } if t.canceled { // has canceled, but not timeout, wait all tasks, canceledTasks = t.waitAll() return append(tasks, ta), true, canceledTasks } if !success { return []*task{}, t.canceled, nil } return []*task{ta}, t.canceled, nil } func (t *taskManager) waitOne() (ta *task, success bool, canceled bool) { if t.num == 0 { return nil, false, false } if t.cancelCh == nil { ta, _ = t.done.Receive() } else { ta, _, canceled = t.receive(t.done.Receive) } t.num-- if canceled { return nil, false, true } delete(t.runningTasks, ta.nodeKey) if ta.originalInput != nil && (ta.err == nil || !isInterruptError(ta.err)) { if sr, ok := ta.originalInput.(streamReader); ok { sr.close() } ta.originalInput = nil } if ta.err != nil { // biz error, jump post processor return ta, true, false } runPostHandler(ta, t.runWrapper) return ta, true, false } func (t *taskManager) waitAll() (successTasks []*task, canceledTasks []*task) { result := make([]*task, 0, t.num) for { ta, success, canceled := t.waitOne() if canceled { for _, rt := range t.runningTasks { canceledTasks = 
append(canceledTasks, rt) } t.runningTasks = make(map[string]*task) t.num = 0 return result, canceledTasks } if !success { return result, nil } result = append(result, ta) } } func (t *taskManager) receive(recv func() (*task, bool)) (ta *task, closed bool, canceled bool) { if t.deadline != nil { // have canceled, receive in a certain time return receiveWithDeadline(recv, *t.deadline) } if t.canceled { // canceled without timeout ta, closed = recv() return ta, closed, false } if t.cancelCh != nil { // have not canceled, receive while listening ta, closed, canceled, t.canceled, t.deadline = receiveWithListening(recv, t.cancelCh) return ta, closed, canceled } // won't cancel ta, closed = recv() return ta, closed, false } func receiveWithDeadline(recv func() (*task, bool), deadline time.Time) (ta *task, closed bool, canceled bool) { now := time.Now() if deadline.Before(now) { return nil, false, true } timeout := deadline.Sub(now) resultCh := make(chan struct{}, 1) go func() { ta, closed = recv() resultCh <- struct{}{} }() timeoutCh := time.After(timeout) select { case <-resultCh: return ta, closed, false case <-timeoutCh: return nil, false, true } } func receiveWithListening(recv func() (*task, bool), cancel chan *time.Duration) (*task, bool, bool, bool, *time.Time) { type pair struct { ta *task closed bool } resultCh := make(chan pair, 1) var timeoutCh <-chan time.Time var deadline *time.Time canceled := false go func() { ta, closed := recv() resultCh <- pair{ta, closed} }() select { case p := <-resultCh: return p.ta, p.closed, false, false, nil case timeout, ok := <-cancel: if !ok { // unreachable break } canceled = true if timeout == nil { // canceled without timeout break } timeoutCh = time.After(*timeout) dt := time.Now().Add(*timeout) deadline = &dt } if timeoutCh != nil { select { case p := <-resultCh: return p.ta, p.closed, false, canceled, deadline case <-timeoutCh: return nil, false, true, canceled, deadline } } p := <-resultCh return p.ta, p.closed, false, 
canceled, nil } func runPreHandler(ta *task, runWrapper runnableCallWrapper) (err error) { defer func() { if e := recover(); e != nil { err = safe.NewPanicErr(fmt.Errorf("panic in pre handler: %v", e), debug.Stack()) } }() if ta.call.preProcessor != nil && !ta.skipPreHandler { nInput, err := runWrapper(ta.ctx, ta.call.preProcessor, ta.input, ta.option...) if err != nil { return fmt.Errorf("run node[%s] pre processor fail: %w", ta.nodeKey, err) } ta.input = nInput } return nil } func runPostHandler(ta *task, runWrapper runnableCallWrapper) { defer func() { if e := recover(); e != nil { ta.err = safe.NewPanicErr(fmt.Errorf("panic in post handler: %v", e), debug.Stack()) } }() if ta.call.postProcessor != nil { nOutput, err := runWrapper(ta.ctx, ta.call.postProcessor, ta.output, ta.option...) if err != nil { ta.err = fmt.Errorf("run node[%s] post processor fail: %w", ta.nodeKey, err) } ta.output = nOutput } } ================================================ FILE: compose/graph_node.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "errors" "reflect" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/internal/generic" ) // the info of most original executable object directly provided by the user type executorMeta struct { // automatically identified based on the way of addNode component component // indicates whether the executable object user provided could execute the callback aspect itself. // if it could, the callback in the corresponding graph node won't be executed // for components, the value comes from callbacks.Checker isComponentCallbackEnabled bool // for components, the value comes from components.Typer // for lambda, the value comes from the user's explicit config // if componentImplType is empty, then the class name or func name in the instance will be inferred, but no guarantee. componentImplType string } type nodeInfo struct { // the name of graph node for display purposes, not unique. // passed from WithNodeName() name string inputKey string outputKey string preProcessor, postProcessor *composableRunnable compileOption *graphCompileOptions // if the node is an AnyGraph, it will need compile options of its own } // graphNode the complete information of the node in graph type graphNode struct { cr *composableRunnable g AnyGraph nodeInfo *nodeInfo executorMeta *executorMeta instance any opts []GraphAddNodeOpt } func (gn *graphNode) getGenericHelper() *genericHelper { var ret *genericHelper if gn.g != nil { ret = gn.g.getGenericHelper() } else if gn.cr != nil { ret = gn.cr.genericHelper } else { return nil } if gn.nodeInfo != nil { if len(gn.nodeInfo.inputKey) > 0 { ret = ret.forMapInput() } if len(gn.nodeInfo.outputKey) > 0 { ret = ret.forMapOutput() } } return ret } func (gn *graphNode) inputType() reflect.Type { if gn.nodeInfo != nil && len(gn.nodeInfo.inputKey) != 0 { return generic.TypeOf[map[string]any]() } // priority follow compile if gn.g != nil { return gn.g.inputType() } else if gn.cr != nil { return gn.cr.inputType 
} return nil } func (gn *graphNode) outputType() reflect.Type { if gn.nodeInfo != nil && len(gn.nodeInfo.outputKey) != 0 { return generic.TypeOf[map[string]any]() } // priority follow compile if gn.g != nil { return gn.g.outputType() } else if gn.cr != nil { return gn.cr.outputType } return nil } func (gn *graphNode) compileIfNeeded(ctx context.Context) (*composableRunnable, error) { var r *composableRunnable if gn.g != nil { cr, err := gn.g.compile(ctx, gn.nodeInfo.compileOption) if err != nil { return nil, err } r = cr gn.cr = cr } else if gn.cr != nil { r = gn.cr } else { return nil, errors.New("no graph or component provided") } r.meta = gn.executorMeta r.nodeInfo = gn.nodeInfo if gn.nodeInfo.outputKey != "" { r = outputKeyedComposableRunnable(gn.nodeInfo.outputKey, r) } if gn.nodeInfo.inputKey != "" { r = inputKeyedComposableRunnable(gn.nodeInfo.inputKey, r) } return r, nil } func parseExecutorInfoFromComponent(c component, executor any) *executorMeta { componentImplType, ok := components.GetType(executor) if !ok { componentImplType = generic.ParseTypeName(reflect.ValueOf(executor)) } return &executorMeta{ component: c, isComponentCallbackEnabled: components.IsCallbacksEnabled(executor), componentImplType: componentImplType, } } func getNodeInfo(opts ...GraphAddNodeOpt) (*nodeInfo, *graphAddNodeOpts) { opt := getGraphAddNodeOpts(opts...) return &nodeInfo{ name: opt.nodeOptions.nodeName, inputKey: opt.nodeOptions.inputKey, outputKey: opt.nodeOptions.outputKey, preProcessor: opt.processor.statePreHandler, postProcessor: opt.processor.statePostHandler, compileOption: newGraphCompileOptions(opt.nodeOptions.graphCompileOption...), }, opt } ================================================ FILE: compose/graph_run.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "errors" "fmt" "reflect" "strings" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/core" "github.com/cloudwego/eino/internal/serialization" ) type chanCall struct { action *composableRunnable writeTo []string writeToBranches []*GraphBranch controls []string // branch must control preProcessor, postProcessor *composableRunnable } type chanBuilder func(dependencies []string, indirectDependencies []string, zeroValue func() any, emptyStream func() streamReader) channel type runner struct { chanSubscribeTo map[string]*chanCall successors map[string][]string dataPredecessors map[string][]string controlPredecessors map[string][]string inputChannels *chanCall chanBuilder chanBuilder // could be nil eager bool dag bool runCtx func(ctx context.Context) context.Context options graphCompileOptions inputType reflect.Type outputType reflect.Type // take effect as a subgraph through toComposableRunnable inputStreamFilter streamMapFilter inputConverter handlerPair inputFieldMappingConverter handlerPair inputConvertStreamPair, outputConvertStreamPair streamConvertPair *genericHelper // checks need to do because cannot check at compile runtimeCheckEdges map[string]map[string]bool runtimeCheckBranches map[string][]bool edgeHandlerManager *edgeHandlerManager preNodeHandlerManager *preNodeHandlerManager preBranchHandlerManager *preBranchHandlerManager checkPointer *checkPointer interruptBeforeNodes []string interruptAfterNodes []string mergeConfigs map[string]FanInMergeConfig } func (r *runner) invoke(ctx context.Context, 
input any, opts ...Option) (any, error) { return r.run(ctx, false, input, opts...) } func (r *runner) transform(ctx context.Context, input streamReader, opts ...Option) (streamReader, error) { s, err := r.run(ctx, true, input, opts...) if err != nil { return nil, err } return s.(streamReader), nil } type runnableCallWrapper func(context.Context, *composableRunnable, any, ...any) (any, error) func runnableInvoke(ctx context.Context, r *composableRunnable, input any, opts ...any) (any, error) { return r.i(ctx, input, opts...) } func runnableTransform(ctx context.Context, r *composableRunnable, input any, opts ...any) (any, error) { return r.t(ctx, input.(streamReader), opts...) } func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Option) (result any, err error) { haveOnStart := false // delay triggering onGraphStart until state initialization is complete, so that the state can be accessed within onGraphStart. defer func() { if !haveOnStart { ctx, input = onGraphStart(ctx, input, isStream) } if err != nil { ctx, err = onGraphError(ctx, err) } else { ctx, result = onGraphEnd(ctx, result, isStream) } }() var runWrapper runnableCallWrapper runWrapper = runnableInvoke if isStream { runWrapper = runnableTransform } // Initialize channel and task managers. cm := r.initChannelManager(isStream) tm := r.initTaskManager(runWrapper, getGraphCancel(ctx), opts...) maxSteps := r.options.maxRunSteps maxSteps, err = r.resolveMaxSteps(maxSteps, opts) if err != nil { return nil, err } // Extract and validate options for each node. optMap, extractErr := extractOption(r.chanSubscribeTo, opts...) if extractErr != nil { return nil, newGraphRunError(fmt.Errorf("graph extract option fail: %w", extractErr)) } // Extract CheckPointID checkPointID, writeToCheckPointID, stateModifier, forceNewRun := getCheckPointInfo(opts...) 
if checkPointID != nil && r.checkPointer.store == nil { return nil, newGraphRunError(fmt.Errorf("receive checkpoint id but have not set checkpoint store")) } // Extract subgraph path, isSubGraph := getNodePath(ctx) // load checkpoint from ctx/store or init graph initialized := false var nextTasks []*task if cp := getCheckPointFromCtx(ctx); cp != nil { // in subgraph, try to load checkpoint from ctx initialized = true ctx, err = r.restoreCheckPointState(ctx, *path, getStateModifier(ctx), cp, isStream, cm) if err != nil { return nil, err } ctx, input = onGraphStart(ctx, input, isStream) haveOnStart = true nextTasks, err = r.restoreTasks(ctx, cp.Inputs, cp.SkipPreHandler, cp.RerunNodes, isStream, optMap) if err != nil { return nil, newGraphRunError(fmt.Errorf("restore tasks fail: %w", err)) } } else if checkPointID != nil && !forceNewRun { cp, err = getCheckPointFromStore(ctx, *checkPointID, r.checkPointer) if err != nil { return nil, newGraphRunError(fmt.Errorf("load checkpoint from store fail: %w", err)) } if cp != nil { // load checkpoint from store initialized = true ctx = setStateModifier(ctx, stateModifier) ctx = setCheckPointToCtx(ctx, cp) ctx, err = r.restoreCheckPointState(ctx, *NewNodePath(), stateModifier, cp, isStream, cm) if err != nil { return nil, err } ctx, input = onGraphStart(ctx, input, isStream) haveOnStart = true nextTasks, err = r.restoreTasks(ctx, cp.Inputs, cp.SkipPreHandler, cp.RerunNodes, isStream, optMap) if err != nil { return nil, newGraphRunError(fmt.Errorf("restore tasks fail: %w", err)) } } } if !initialized { // have not inited from checkpoint if r.runCtx != nil { ctx = r.runCtx(ctx) } ctx, input = onGraphStart(ctx, input, isStream) haveOnStart = true var isEnd bool nextTasks, result, isEnd, err = r.calculateNextTasks(ctx, []*task{{ nodeKey: START, call: r.inputChannels, output: input, }}, isStream, cm, optMap) if err != nil { return nil, newGraphRunError(fmt.Errorf("calculate next tasks fail: %w", err)) } if isEnd { return result, nil 
} if len(nextTasks) == 0 { return nil, newGraphRunError(fmt.Errorf("no tasks to execute after graph start")) } if keys := getHitKey(nextTasks, r.interruptBeforeNodes); len(keys) > 0 { tempInfo := newInterruptTempInfo() tempInfo.interruptBeforeNodes = append(tempInfo.interruptBeforeNodes, keys...) return nil, r.handleInterrupt(ctx, tempInfo, nextTasks, cm.channels, isStream, isSubGraph, writeToCheckPointID, ) } } // used to reporting NoTask error var lastCompletedTask []*task // Main execution loop. for step := 0; ; step++ { // Check for context cancellation. select { case <-ctx.Done(): _, _ = tm.waitAll() return nil, newGraphRunError(fmt.Errorf("context has been canceled: %w", ctx.Err())) default: } if !r.dag && step >= maxSteps { return nil, newGraphRunError(ErrExceedMaxSteps) } // 1. submit next tasks // 2. get completed tasks // 3. calculate next tasks err = tm.submit(nextTasks) if err != nil { return nil, newGraphRunError(fmt.Errorf("failed to submit tasks: %w", err)) } var totalCanceledTasks []*task completedTasks, canceled, canceledTasks := tm.wait() totalCanceledTasks = append(totalCanceledTasks, canceledTasks...) tempInfo := newInterruptTempInfo() tempInfo.collectCanceledInfo(canceled, canceledTasks, completedTasks) err = r.resolveInterruptCompletedTasks(tempInfo, completedTasks) if err != nil { return nil, err // err has been wrapped } if len(tempInfo.subGraphInterrupts)+len(tempInfo.interruptRerunNodes) > 0 { var newCompletedTasks []*task newCompletedTasks, canceledTasks = tm.waitAll() totalCanceledTasks = append(totalCanceledTasks, canceledTasks...) 
for _, ct := range canceledTasks { // handle timeout tasks as rerun tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, ct.nodeKey) } err = r.resolveInterruptCompletedTasks(tempInfo, newCompletedTasks) if err != nil { return nil, err // err has been wrapped } // subgraph has interrupted // save other completed tasks to channel // save interrupted subgraph as next task with SkipPreHandler // report current graph interrupt info return nil, r.handleInterruptWithSubGraphAndRerunNodes( ctx, tempInfo, append(append(completedTasks, newCompletedTasks...), totalCanceledTasks...), // canceled tasks are handled as rerun writeToCheckPointID, isSubGraph, cm, isStream, ) } if len(completedTasks) == 0 { return nil, newGraphRunError(fmt.Errorf("no tasks to execute, last completed nodes: %v", printTask(lastCompletedTask))) } lastCompletedTask = completedTasks var isEnd bool nextTasks, result, isEnd, err = r.calculateNextTasks(ctx, completedTasks, isStream, cm, optMap) if err != nil { return nil, newGraphRunError(fmt.Errorf("failed to calculate next tasks: %w", err)) } if isEnd { return result, nil } tempInfo.interruptBeforeNodes = getHitKey(nextTasks, r.interruptBeforeNodes) if len(tempInfo.interruptBeforeNodes) > 0 || len(tempInfo.interruptAfterNodes) > 0 { var newCompletedTasks []*task newCompletedTasks, canceledTasks = tm.waitAll() totalCanceledTasks = append(totalCanceledTasks, canceledTasks...) 
for _, ct := range canceledTasks { tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, ct.nodeKey) } err = r.resolveInterruptCompletedTasks(tempInfo, newCompletedTasks) if err != nil { return nil, err // err has been wrapped } if len(tempInfo.subGraphInterrupts)+len(tempInfo.interruptRerunNodes) > 0 { return nil, r.handleInterruptWithSubGraphAndRerunNodes( ctx, tempInfo, append(append(completedTasks, newCompletedTasks...), totalCanceledTasks...), writeToCheckPointID, isSubGraph, cm, isStream, ) } var newNextTasks []*task newNextTasks, result, isEnd, err = r.calculateNextTasks(ctx, newCompletedTasks, isStream, cm, optMap) if err != nil { return nil, newGraphRunError(fmt.Errorf("failed to calculate next tasks: %w", err)) } if isEnd { return result, nil } tempInfo.interruptBeforeNodes = append(tempInfo.interruptBeforeNodes, getHitKey(newNextTasks, r.interruptBeforeNodes)...) // simple interrupt return nil, r.handleInterrupt(ctx, tempInfo, append(nextTasks, newNextTasks...), cm.channels, isStream, isSubGraph, writeToCheckPointID) } } } func (r *runner) resolveMaxSteps(maxSteps int, opts []Option) (int, error) { if r.dag { for i := range opts { if opts[i].maxRunSteps > 0 { return 0, newGraphRunError(fmt.Errorf("cannot set max run steps in dag")) } } return maxSteps, nil } for i := range opts { if opts[i].maxRunSteps > 0 { maxSteps = opts[i].maxRunSteps } } if maxSteps < 1 { return 0, newGraphRunError(errors.New("max run steps limit must be at least 1")) } return maxSteps, nil } func (r *runner) restoreCheckPointState( ctx context.Context, path NodePath, sm StateModifier, cp *checkpoint, isStream bool, cm *channelManager, ) (context.Context, error) { err := r.checkPointer.restoreCheckPoint(cp, isStream) if err != nil { return ctx, newGraphRunError(fmt.Errorf("restore checkpoint fail: %w", err)) } err = cm.loadChannels(cp.Channels) if err != nil { return ctx, newGraphRunError(err) } if sm != nil && cp.State != nil { err = sm(ctx, path, cp.State) if err != 
nil { return ctx, newGraphRunError(fmt.Errorf("state modifier fail: %w", err)) } } if cp.State != nil { isResumeTarget, hasData, data := GetResumeContext[any](ctx) if isResumeTarget && hasData { cp.State = data } var parent *internalState if prev := ctx.Value(stateKey{}); prev != nil { if p, ok := prev.(*internalState); ok { parent = p } } ctx = context.WithValue(ctx, stateKey{}, &internalState{state: cp.State, parent: parent}) } return ctx, nil } func newInterruptTempInfo() *interruptTempInfo { return &interruptTempInfo{ subGraphInterrupts: map[string]*subGraphInterruptError{}, interruptRerunExtra: map[string]any{}, } } type interruptTempInfo struct { subGraphInterrupts map[string]*subGraphInterruptError interruptRerunNodes []string interruptBeforeNodes []string interruptAfterNodes []string interruptRerunExtra map[string]any signals []*core.InterruptSignal } func (ti *interruptTempInfo) collectCanceledInfo(canceled bool, canceledTasks, completedTasks []*task) { if !canceled { return } if len(canceledTasks) > 0 { for _, t := range canceledTasks { ti.interruptRerunNodes = append(ti.interruptRerunNodes, t.nodeKey) } } else { for _, t := range completedTasks { ti.interruptAfterNodes = append(ti.interruptAfterNodes, t.nodeKey) } } } func (r *runner) resolveInterruptCompletedTasks(tempInfo *interruptTempInfo, completedTasks []*task) (err error) { for _, completedTask := range completedTasks { if completedTask.err != nil { if info := isSubGraphInterrupt(completedTask.err); info != nil { tempInfo.subGraphInterrupts[completedTask.nodeKey] = info tempInfo.signals = append(tempInfo.signals, info.signal) continue } ire := &core.InterruptSignal{} if errors.As(completedTask.err, &ire) { tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, completedTask.nodeKey) if ire.Info != nil { tempInfo.interruptRerunExtra[completedTask.nodeKey] = ire.InterruptInfo.Info } tempInfo.signals = append(tempInfo.signals, ire) continue } return 
wrapGraphNodeError(completedTask.nodeKey, completedTask.err) } for _, key := range r.interruptAfterNodes { if key == completedTask.nodeKey { tempInfo.interruptAfterNodes = append(tempInfo.interruptAfterNodes, key) break } } } return nil } func getHitKey(tasks []*task, keys []string) []string { var ret []string for _, t := range tasks { for _, key := range keys { if key == t.nodeKey { ret = append(ret, t.nodeKey) } } } return ret } func (r *runner) handleInterrupt( ctx context.Context, tempInfo *interruptTempInfo, nextTasks []*task, channels map[string]channel, isStream bool, isSubGraph bool, checkPointID *string, ) error { cp := &checkpoint{ Channels: channels, Inputs: make(map[string]any), SkipPreHandler: map[string]bool{}, } if r.runCtx != nil { // current graph has enable state if state, ok := ctx.Value(stateKey{}).(*internalState); ok { cp.State = state.state } } intInfo := &InterruptInfo{ State: cp.State, AfterNodes: tempInfo.interruptAfterNodes, BeforeNodes: tempInfo.interruptBeforeNodes, RerunNodes: tempInfo.interruptRerunNodes, RerunNodesExtra: tempInfo.interruptRerunExtra, SubGraphs: make(map[string]*InterruptInfo), } var info any if cp.State != nil { copiedState, err := deepCopyState(cp.State) if err != nil { return fmt.Errorf("failed to copy state: %w", err) } info = copiedState } is, err := core.Interrupt(ctx, info, nil, tempInfo.signals) if err != nil { return fmt.Errorf("failed to interrupt: %w", err) } cp.InterruptID2Addr, cp.InterruptID2State = core.SignalToPersistenceMaps(is) for _, t := range nextTasks { cp.Inputs[t.nodeKey] = t.input } err = r.checkPointer.convertCheckPoint(cp, isStream) if err != nil { return fmt.Errorf("failed to convert checkpoint: %w", err) } if isSubGraph { return &subGraphInterruptError{ Info: intInfo, CheckPoint: cp, signal: is, } } else if checkPointID != nil { err := r.checkPointer.set(ctx, *checkPointID, cp) if err != nil { return fmt.Errorf("failed to set checkpoint: %w, checkPointID: %s", err, *checkPointID) } } 
intInfo.InterruptContexts = core.ToInterruptContexts(is, nil) return &interruptError{Info: intInfo} } // deepCopyState creates a deep copy of the state using serialization func deepCopyState(state any) (any, error) { if state == nil { return nil, nil } serializer := &serialization.InternalSerializer{} data, err := serializer.Marshal(state) if err != nil { return nil, fmt.Errorf("failed to marshal state: %w", err) } // Create new instance of the same type stateType := reflect.TypeOf(state) if stateType.Kind() == reflect.Ptr { stateType = stateType.Elem() } newState := reflect.New(stateType).Interface() if err := serializer.Unmarshal(data, newState); err != nil { return nil, fmt.Errorf("failed to unmarshal state: %w", err) } return newState, nil } func (r *runner) handleInterruptWithSubGraphAndRerunNodes( ctx context.Context, tempInfo *interruptTempInfo, completeTasks []*task, checkPointID *string, isSubGraph bool, cm *channelManager, isStream bool, ) error { var rerunTasks, subgraphTasks, otherTasks []*task skipPreHandler := map[string]bool{} for _, t := range completeTasks { if _, ok := tempInfo.subGraphInterrupts[t.nodeKey]; ok { subgraphTasks = append(subgraphTasks, t) skipPreHandler[t.nodeKey] = true // subgraph won't run pre-handler again, but rerun nodes will continue } rerun := false for _, key := range tempInfo.interruptRerunNodes { if key == t.nodeKey { rerunTasks = append(rerunTasks, t) rerun = true break } } if !rerun { otherTasks = append(otherTasks, t) } } // forward completed tasks toValue, controls, err := r.resolveCompletedTasks(ctx, otherTasks, isStream, cm) if err != nil { return fmt.Errorf("failed to resolve completed tasks in interrupt: %w", err) } err = cm.updateValues(ctx, toValue) if err != nil { return fmt.Errorf("failed to update values in interrupt: %w", err) } err = cm.updateDependencies(ctx, controls) if err != nil { return fmt.Errorf("failed to update dependencies in interrupt: %w", err) } cp := &checkpoint{ Channels: cm.channels, 
Inputs: make(map[string]any), SkipPreHandler: skipPreHandler, SubGraphs: make(map[string]*checkpoint), } if r.runCtx != nil { // current graph has enable state if state, ok := ctx.Value(stateKey{}).(*internalState); ok { cp.State = state.state } } intInfo := &InterruptInfo{ State: cp.State, BeforeNodes: tempInfo.interruptBeforeNodes, AfterNodes: tempInfo.interruptAfterNodes, RerunNodes: tempInfo.interruptRerunNodes, RerunNodesExtra: tempInfo.interruptRerunExtra, SubGraphs: make(map[string]*InterruptInfo), } var info any if cp.State != nil { copiedState, err_ := deepCopyState(cp.State) if err_ != nil { return fmt.Errorf("failed to copy state: %w", err_) } info = copiedState } is, err := core.Interrupt(ctx, info, nil, tempInfo.signals) if err != nil { return fmt.Errorf("failed to interrupt: %w", err) } cp.InterruptID2Addr, cp.InterruptID2State = core.SignalToPersistenceMaps(is) for _, t := range subgraphTasks { cp.RerunNodes = append(cp.RerunNodes, t.nodeKey) cp.SubGraphs[t.nodeKey] = tempInfo.subGraphInterrupts[t.nodeKey].CheckPoint intInfo.SubGraphs[t.nodeKey] = tempInfo.subGraphInterrupts[t.nodeKey].Info } for _, t := range rerunTasks { cp.RerunNodes = append(cp.RerunNodes, t.nodeKey) if t.originalInput != nil { cp.Inputs[t.nodeKey] = t.originalInput } } err = r.checkPointer.convertCheckPoint(cp, isStream) if err != nil { return fmt.Errorf("failed to convert checkpoint: %w", err) } if isSubGraph { return &subGraphInterruptError{ Info: intInfo, CheckPoint: cp, signal: is, } } else if checkPointID != nil { err = r.checkPointer.set(ctx, *checkPointID, cp) if err != nil { return fmt.Errorf("failed to set checkpoint: %w, checkPointID: %s", err, *checkPointID) } } intInfo.InterruptContexts = core.ToInterruptContexts(is, nil) return &interruptError{Info: intInfo} } func (r *runner) calculateNextTasks(ctx context.Context, completedTasks []*task, isStream bool, cm *channelManager, optMap map[string][]any) ([]*task, any, bool, error) { writeChannelValues, controls, err := 
r.resolveCompletedTasks(ctx, completedTasks, isStream, cm) if err != nil { return nil, nil, false, err } nodeMap, err := cm.updateAndGet(ctx, writeChannelValues, controls) if err != nil { return nil, nil, false, fmt.Errorf("failed to update and get channels: %w", err) } var nextTasks []*task if len(nodeMap) > 0 { // Check if we've reached the END node. if v, ok := nodeMap[END]; ok { return nil, v, true, nil } // Create and submit the next batch of tasks. nextTasks, err = r.createTasks(ctx, nodeMap, optMap) if err != nil { return nil, nil, false, fmt.Errorf("failed to create tasks: %w", err) } } return nextTasks, nil, false, nil } func (r *runner) createTasks(ctx context.Context, nodeMap map[string]any, optMap map[string][]any) ([]*task, error) { var nextTasks []*task for nodeKey, nodeInput := range nodeMap { call, ok := r.chanSubscribeTo[nodeKey] if !ok { return nil, fmt.Errorf("node[%s] has not been registered", nodeKey) } if call.action.nodeInfo != nil && call.action.nodeInfo.compileOption != nil { ctx = forwardCheckPoint(ctx, nodeKey) } nextTasks = append(nextTasks, &task{ ctx: AppendAddressSegment(ctx, AddressSegmentNode, nodeKey), nodeKey: nodeKey, call: call, input: nodeInput, option: optMap[nodeKey], }) } return nextTasks, nil } func getCheckPointInfo(opts ...Option) (checkPointID *string, writeToCheckPointID *string, stateModifier StateModifier, forceNewRun bool) { for _, opt := range opts { if opt.checkPointID != nil { checkPointID = opt.checkPointID } if opt.writeToCheckPointID != nil { writeToCheckPointID = opt.writeToCheckPointID } if opt.stateModifier != nil { stateModifier = opt.stateModifier } forceNewRun = opt.forceNewRun } if writeToCheckPointID == nil { writeToCheckPointID = checkPointID } return } func (r *runner) restoreTasks( ctx context.Context, inputs map[string]any, skipPreHandler map[string]bool, rerunNodes []string, isStream bool, optMap map[string][]any) ([]*task, error) { ret := make([]*task, 0, len(inputs)) for _, key := range 
rerunNodes {
		if _, hasInput := inputs[key]; hasInput {
			continue
		}
		call, ok := r.chanSubscribeTo[key]
		if !ok {
			return nil, fmt.Errorf("channel[%s] from checkpoint is not registered", key)
		}
		if isStream {
			inputs[key] = call.action.inputEmptyStream()
		} else {
			inputs[key] = call.action.inputZeroValue()
		}
	}
	for key, input := range inputs {
		call, ok := r.chanSubscribeTo[key]
		if !ok {
			return nil, fmt.Errorf("channel[%s] from checkpoint is not registered", key)
		}

		if call.action.nodeInfo != nil && call.action.nodeInfo.compileOption != nil {
			// sub graph
			// NOTE(review): this reassigns the ctx shared by later iterations, so
			// tasks created after a sub-graph task also carry the forwarded
			// checkpoint; confirm that is intended.
			ctx = forwardCheckPoint(ctx, key)
		}

		newTask := &task{
			ctx:            AppendAddressSegment(ctx, AddressSegmentNode, key),
			nodeKey:        key,
			call:           call,
			input:          input,
			option:         nil,
			skipPreHandler: skipPreHandler[key],
		}
		if opt, ok := optMap[key]; ok {
			newTask.option = opt
		}
		ret = append(ret, newTask)
	}
	return ret, nil
}

// resolveCompletedTasks turns the outputs of completedTasks into the values to
// be written into successor channels for the next step.
//
// It returns:
//   - writeChannelValues: successor key -> (completed node key -> value); every
//     branch-selected successor and every t.call.writeTo successor of each
//     completed task gets its own copy of that task's output;
//   - newDependencies: successor key -> completed node keys it now depends on,
//     recorded for control successors (t.call.controls) and for
//     branch-selected successors;
//   - an error if evaluating any task's branches fails.
//
// For a task with W = len(writeTo) data successors and B = len(writeToBranches)
// branches, the output is first copied W+2B times. The last B copies are handed
// to calculateBranch as branch conditions. The first W+B copies are reserved as
// data. Once the K branch-selected successors are known, the last reserved copy
// is re-copied so that at least K+W data copies exist: one per entry of
// nextNodeKeys, which lists branch successors first, then writeTo.
//
// NOTE(review): when K < B (or nextNodeKeys is empty) some reserved copies are
// never written anywhere. For stream outputs they are not closed here; confirm
// the stream-copy implementation tolerates abandoned copies.
func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool, cm *channelManager) (map[string]map[string]any, map[string][]string, error) {
	writeChannelValues := make(map[string]map[string]any)
	newDependencies := make(map[string][]string)
	for _, t := range completedTasks {
		for _, key := range t.call.controls {
			newDependencies[key] = append(newDependencies[key], t.nodeKey)
		}

		// update channel & new_next_tasks
		vs := copyItem(t.output, len(t.call.writeTo)+len(t.call.writeToBranches)*2)
		nextNodeKeys, err := r.calculateBranch(ctx, t.nodeKey, t.call,
			vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream, cm)
		if err != nil {
			return nil, nil, fmt.Errorf("calculate next step fail, node: %s, error: %w", t.nodeKey, err)
		}

		for _, key := range nextNodeKeys {
			newDependencies[key] = append(newDependencies[key], t.nodeKey)
		}
		nextNodeKeys = append(nextNodeKeys, t.call.writeTo...)

		// If branches generate more than one successor, the inputs need to be copied accordingly.
		if len(nextNodeKeys) > 0 {
			// toCopyNum may be <= 0; copyItem then returns a single element.
			toCopyNum := len(nextNodeKeys) - len(t.call.writeTo) - len(t.call.writeToBranches)
			nVs := copyItem(vs[len(t.call.writeTo)+len(t.call.writeToBranches)-1], toCopyNum+1)
			vs = append(vs[:len(t.call.writeTo)+len(t.call.writeToBranches)-1], nVs...)

			for i, next := range nextNodeKeys {
				if _, ok := writeChannelValues[next]; !ok {
					writeChannelValues[next] = make(map[string]any)
				}
				writeChannelValues[next][t.nodeKey] = vs[i]
			}
		}
	}
	return writeChannelValues, newDependencies, nil
}

// calculateBranch evaluates every branch attached to curNodeKey. It returns the
// successor keys they selected, concatenated in branch order; duplicates across
// branches are kept.
//
// input must hold at least one value per branch. input[i] feeds
// startChan.writeToBranches[i] and is overwritten in place with the result of
// the pre-branch handler. In stream mode each input[i] must be a streamReader.
//
// Every end node declared by some branch but selected by no branch is reported
// to cm.reportBranch as skipped. An error is returned if a pre-branch handler
// fails, a branch fails to evaluate, or reportBranch fails.
func (r *runner) calculateBranch(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool, cm *channelManager) ([]string, error) {
	if len(input) < len(startChan.writeToBranches) {
		// unreachable
		return nil, errors.New("calculate next input length is shorter than branches")
	}

	ret := make([]string, 0, len(startChan.writeToBranches))

	skippedNodes := make(map[string]struct{})
	for i, branch := range startChan.writeToBranches {
		// check branch input type if needed
		var err error
		input[i], err = r.preBranchHandlerManager.handle(curNodeKey, i, input[i], isStream)
		if err != nil {
			return nil, fmt.Errorf("branch[%s]-[%d] pre handler fail: %w", curNodeKey, branch.idx, err)
		}

		// process branch output
		var ws []string
		if isStream {
			ws, err = branch.collect(ctx, input[i].(streamReader))
			if err != nil {
				return nil, fmt.Errorf("branch collect run error: %w", err)
			}
		} else {
			ws, err = branch.invoke(ctx, input[i])
			if err != nil {
				return nil, fmt.Errorf("branch invoke run error: %w", err)
			}
		}

		// Every end node this branch did not select is a skip candidate.
		chosen := make(map[string]struct{}, len(ws))
		for _, w := range ws {
			chosen[w] = struct{}{}
		}
		for node := range branch.endNodes {
			if _, ok := chosen[node]; !ok {
				skippedNodes[node] = struct{}{}
			}
		}

		ret = append(ret, ws...)
	}

	// When a node has multiple branches, a succeeding node may be selected by some
	// branches and discarded by others; in that case it must not be skipped.
	// delete is a no-op for absent keys, so no existence check is needed.
	for _, node := range ret {
		delete(skippedNodes, node)
	}
	var skippedNodeList []string
	for node := range skippedNodes {
		skippedNodeList = append(skippedNodeList, node)
	}

	err := cm.reportBranch(curNodeKey, skippedNodeList)
	if err != nil {
		return nil, err
	}
	return ret, nil
}

// initTaskManager creates the taskManager for one run. needAll is the inverse
// of r.eager. When cancelVal is non-nil, its channel becomes the task manager's
// cancel channel and persistRerunInput is enabled.
func (r *runner) initTaskManager(runWrapper runnableCallWrapper, cancelVal *graphCancelChanVal, opts ...Option) *taskManager {
	tm := &taskManager{
		runWrapper:        runWrapper,
		opts:              opts,
		needAll:           !r.eager,
		done:              internal.NewUnboundedChan[*task](),
		runningTasks:      make(map[string]*task),
		persistRerunInput: cancelVal != nil,
	}
	if cancelVal != nil {
		tm.cancelCh = cancelVal.ch
	}
	return tm
}

// initChannelManager builds one channel per subscribed node plus one for END.
// It uses r.chanBuilder, falling back to pregelChannelBuilder when it is nil.
// It converts the runner's predecessor lists into sets and applies any merge
// config registered in r.mergeConfigs for a channel.
func (r *runner) initChannelManager(isStream bool) *channelManager {
	builder := r.chanBuilder
	if builder == nil {
		builder = pregelChannelBuilder
	}

	chs := make(map[string]channel, len(r.chanSubscribeTo)+1)
	for ch, call := range r.chanSubscribeTo {
		chs[ch] = builder(r.controlPredecessors[ch], r.dataPredecessors[ch], call.action.inputZeroValue, call.action.inputEmptyStream)
	}
	// END has no subscribed call; its zero value / empty stream come from the graph output type.
	chs[END] = builder(r.controlPredecessors[END], r.dataPredecessors[END], r.outputZeroValue, r.outputEmptyStream)

	// Convert predecessor lists into sets.
	dataPredecessors := make(map[string]map[string]struct{}, len(r.dataPredecessors))
	for k, vs := range r.dataPredecessors {
		set := make(map[string]struct{}, len(vs))
		for _, v := range vs {
			set[v] = struct{}{}
		}
		dataPredecessors[k] = set
	}
	controlPredecessors := make(map[string]map[string]struct{}, len(r.controlPredecessors))
	for k, vs := range r.controlPredecessors {
		set := make(map[string]struct{}, len(vs))
		for _, v := range vs {
			set[v] = struct{}{}
		}
		controlPredecessors[k] = set
	}

	for k, v := range chs {
		if cfg, ok := r.mergeConfigs[k]; ok {
			v.setMergeConfig(cfg)
		}
	}

	return &channelManager{
		isStream:              isStream,
		channels:              chs,
		successors:            r.successors,
		dataPredecessors:      dataPredecessors,
		controlPredecessors:   controlPredecessors,
		edgeHandlerManager:    r.edgeHandlerManager,
		preNodeHandlerManager: r.preNodeHandlerManager,
	}
}

// toComposableRunnable wraps the runner as a composableRunnable. Its invoke and
// transform entry points convert the untyped options back to []Option before
// delegating to r.invoke / r.transform.
func (r *runner) toComposableRunnable() *composableRunnable {
	cr := &composableRunnable{
		i: func(ctx context.Context, input any, opts ...any) (output any, err error) {
			tos, err := convertOption[Option](opts...)
			if err != nil {
				return nil, err
			}
			return r.invoke(ctx, input, tos...)
		},
		t: func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) {
			tos, err := convertOption[Option](opts...)
			if err != nil {
				return nil, err
			}
			return r.transform(ctx, input, tos...)
		},

		inputType:     r.inputType,
		outputType:    r.outputType,
		genericHelper: r.genericHelper,
		optionType:    nil, // if option type is nil, graph will transmit all options.
	}

	return cr
}

// copyItem returns n values derived from item.
//   - If n < 2 (including n <= 0) it returns a single-element slice holding
//     item itself; callers rely on this.
//   - If item is a streamReader it returns the n readers produced by s.copy(n).
//   - Otherwise every element is the same item.
func copyItem(item any, n int) []any {
	if n < 2 {
		return []any{item}
	}

	ret := make([]any, n)
	if s, ok := item.(streamReader); ok {
		ss := s.copy(n)
		for i := range ret {
			ret[i] = ss[i]
		}

		return ret
	}

	for i := range ret {
		ret[i] = item
	}

	return ret
}

// printTask formats the node keys of ts as "[a, b, c]"; an empty slice yields "[]".
func printTask(ts []*task) string {
	if len(ts) == 0 {
		return "[]"
	}
	sb := strings.Builder{}
	sb.WriteString("[")
	for i := 0; i < len(ts)-1; i++ {
		sb.WriteString(ts[i].nodeKey)
		sb.WriteString(", ")
	}
	sb.WriteString(ts[len(ts)-1].nodeKey)
	sb.WriteString("]")
	return sb.String()
}

================================================
FILE: compose/graph_test.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/

package compose

import (
	"context"
	"fmt"
	"io"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/schema"
)

// TestSingleGraph builds START -> prompt -> model -> END and runs it through
// Invoke, Stream and Transform. It then checks that an input lacking the
// {location} template variable makes all three entry points return an error.
func TestSingleGraph(t *testing.T) {
	const (
		nodeOfModel  = "model"
		nodeOfPrompt = "prompt"
	)

	ctx := context.Background()
	g := NewGraph[map[string]any, *schema.Message]()

	pt := prompt.FromMessages(schema.FString,
		schema.UserMessage("what's the weather in {location}?"),
	)

	err := g.AddChatTemplateNode("prompt", pt)
	assert.NoError(t, err)

	cm := &chatModel{
		msgs: []*schema.Message{
			{
				Role:    schema.Assistant,
				Content: "the weather is good",
			},
		},
	}

	err = g.AddChatModelNode(nodeOfModel, cm, WithNodeName("MockChatModel"))
	assert.NoError(t, err)

	err = g.AddEdge(START, nodeOfPrompt)
	assert.NoError(t, err)

	err = g.AddEdge(nodeOfPrompt, nodeOfModel)
	assert.NoError(t, err)

	err = g.AddEdge(nodeOfModel, END)
	assert.NoError(t, err)

	r, err := g.Compile(context.Background(), WithMaxRunSteps(10))
	assert.NoError(t, err)

	in := map[string]any{"location": "beijing"}
	_, err = r.Invoke(ctx, in)
	assert.NoError(t, err)

	// stream
	s, err := r.Stream(ctx, in)
	assert.NoError(t, err)

	_, err = concatStreamReader(s)
	assert.NoError(t, err)

	sr, sw := schema.Pipe[map[string]any](1)
	_ = sw.Send(in, nil)
	sw.Close()

	// transform
	s, err = r.Transform(ctx, sr)
	assert.NoError(t, err)

	_, err = concatStreamReader(s)
	assert.NoError(t, err)

	// error test
	// NOTE(review): assert.Errorf only checks that err is non-nil; its string
	// argument is the failure message, not an expected substring.
	// assert.ErrorContains would pin the text — confirm the real message first.
	in = map[string]any{"wrong key": 1}
	_, err = r.Invoke(ctx, in)
	assert.Errorf(t, err, "could not find key: location")
	_, err = r.Stream(ctx, in)
	assert.Errorf(t, err, "could not find key: location")

	sr, sw = schema.Pipe[map[string]any](1)
	_ = sw.Send(in, nil)
	sw.Close()

	_, err = r.Transform(ctx, sr)
	assert.Errorf(t, err, "could not find key: location")
}

// person is the interface accepted by the second node of
// TestGraphWithImplementableType; *doctor implements it.
type person interface {
	Say() string
}

// doctor is a concrete type whose pointer satisfies person.
type doctor struct {
	say string
}

// Say returns the text the doctor was created with.
func (d *doctor) Say() string {
	return d.say
}

// TestGraphWithImplementableType checks the following:
//   - an edge from a node producing *doctor to a node consuming person is accepted;
//   - a per-run WithRuntimeMaxSteps(1) makes Invoke fail with "exceeds max steps",
//     even though the graph was compiled with WithMaxRunSteps(10);
//   - without that option, Invoke and Stream return the input unchanged.
func TestGraphWithImplementableType(t *testing.T) {
	const (
		node1 = "1st"
		node2 = "2nd"
	)

	ctx := context.Background()

	g := NewGraph[string, string]()

	err := g.AddLambdaNode(node1, InvokableLambda(func(ctx context.Context, input string) (output *doctor, err error) {
		return &doctor{say: input}, nil
	}))
	assert.NoError(t, err)

	err = g.AddLambdaNode(node2, InvokableLambda(func(ctx context.Context, input person) (output string, err error) {
		return input.Say(), nil
	}))
	assert.NoError(t, err)

	err = g.AddEdge(START, node1)
	assert.NoError(t, err)

	err = g.AddEdge(node1, node2)
	assert.NoError(t, err)

	err = g.AddEdge(node2, END)
	assert.NoError(t, err)

	r, err := g.Compile(context.Background(), WithMaxRunSteps(10))
	assert.NoError(t, err)

	// NOTE(review): the next two Invoke calls are identical; the repeat looks redundant.
	_, err = r.Invoke(ctx, "how are you", WithRuntimeMaxSteps(1))
	assert.Error(t, err)
	assert.ErrorContains(t, err, "exceeds max steps")

	_, err = r.Invoke(ctx, "how are you", WithRuntimeMaxSteps(1))
	assert.Error(t, err)
	assert.ErrorContains(t, err, "exceeds max steps")

	out, err := r.Invoke(ctx, "how are you")
	assert.NoError(t, err)
	assert.Equal(t, "how are you", out)

	outStream, err := r.Stream(ctx, "i'm fine")
	assert.NoError(t, err)
	defer outStream.Close()

	say, err := outStream.Recv()
	assert.NoError(t, err)
	assert.Equal(t, "i'm fine", say)
}

// TestNestedGraph places a prompt -> model sub-graph between two lambdas. It runs
// the outer graph through Invoke, Stream, Collect and Transform with a callback
// handler attached, and asserts only that no step returns an error.
func TestNestedGraph(t *testing.T) {
	const (
		nodeOfLambda1  = "lambda1"
		nodeOfLambda2  = "lambda2"
		nodeOfSubGraph = "sub_graph"
		nodeOfModel    = "model"
		nodeOfPrompt   = "prompt"
	)

	ctx := context.Background()
	g := NewGraph[string, *schema.Message]()
	sg := NewGraph[map[string]any, *schema.Message]()

	l1 := InvokableLambda[string, map[string]any](
		func(ctx context.Context, input string) (output map[string]any, err error) {
			return map[string]any{"location": input}, nil
		})

	l2 := InvokableLambda[*schema.Message, *schema.Message](
		func(ctx context.Context, input *schema.Message) (output *schema.Message, err error) {
			input.Content = fmt.Sprintf("after lambda 2: %s", input.Content)
			return input, nil
		})

	pt := prompt.FromMessages(schema.FString,
		schema.UserMessage("what's the weather in {location}?"),
	)

	err := sg.AddChatTemplateNode("prompt", pt)
	assert.NoError(t, err)

	cm := &chatModel{
		msgs: []*schema.Message{
			{
				Role:    schema.Assistant,
				Content: "the weather is good",
			},
		},
	}

	err = sg.AddChatModelNode(nodeOfModel, cm, WithNodeName("MockChatModel"))
	assert.NoError(t, err)

	err = sg.AddEdge(START, nodeOfPrompt)
	assert.NoError(t, err)

	err = sg.AddEdge(nodeOfPrompt, nodeOfModel)
	assert.NoError(t, err)

	err = sg.AddEdge(nodeOfModel, END)
	assert.NoError(t, err)

	err = g.AddLambdaNode(nodeOfLambda1, l1, WithNodeName("Lambda1"))
	assert.NoError(t, err)

	err = g.AddGraphNode(nodeOfSubGraph, sg, WithNodeName("SubGraphName"))
	assert.NoError(t, err)

	err = g.AddLambdaNode(nodeOfLambda2, l2, WithNodeName("Lambda2"))
	assert.NoError(t, err)

	err = g.AddEdge(START, nodeOfLambda1)
	assert.NoError(t, err)

	err = g.AddEdge(nodeOfLambda1, nodeOfSubGraph)
	assert.NoError(t, err)

	err = g.AddEdge(nodeOfSubGraph, nodeOfLambda2)
	assert.NoError(t, err)

	err = g.AddEdge(nodeOfLambda2, END)
	assert.NoError(t, err)

	r, err := g.Compile(context.Background(),
		WithMaxRunSteps(10),
		WithGraphName("GraphName"),
	)
	assert.NoError(t, err)

	// NOTE(review): ck is a plain string used as a context key (staticcheck
	// SA1029 recommends a dedicated key type). The depth value these callbacks
	// maintain is never asserted.
	ck := "depth"
	cb := callbacks.NewHandlerBuilder().
		OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
			v, ok := ctx.Value(ck).(int)
			if ok {
				v++
			}

			return context.WithValue(ctx, ck, v)
		}).
		OnStartWithStreamInputFn(func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
			input.Close()

			v, ok := ctx.Value(ck).(int)
			if ok {
				v++
			}

			return context.WithValue(ctx, ck, v)
		}).Build()

	// invoke
	_, err = r.Invoke(ctx, "london", WithCallbacks(cb))
	assert.NoError(t, err)

	// stream
	// NOTE(review): this loop (and the one below) only exits on io.EOF; a
	// non-EOF error is reported but the loop keeps calling Recv.
	rs, err := r.Stream(ctx, "london", WithCallbacks(cb))
	assert.NoError(t, err)
	for {
		_, err = rs.Recv()
		if err == io.EOF {
			break
		}
		assert.NoError(t, err)
	}

	// collect
	sr, sw := schema.Pipe[string](5)
	_ = sw.Send("london", nil)
	sw.Close()

	_, err = r.Collect(ctx, sr, WithCallbacks(cb))
	assert.NoError(t, err)

	// transform
	sr, sw = schema.Pipe[string](5)
	_ = sw.Send("london", nil)
	sw.Close()

	rt, err := r.Transform(ctx, sr, WithCallbacks(cb))
	assert.NoError(t, err)
	for {
		_, err = rt.Recv()
		if err == io.EOF {
			break
		}
		assert.NoError(t, err)
	}
}

// chatModel is a test double for a chat model:
//   - BindTools is a no-op;
//   - Generate returns msgs[0], so it panics if msgs is empty;
//   - Stream sends every message in msgs from a goroutine, then closes the stream.
type chatModel struct {
	msgs []*schema.Message
}

func (c *chatModel) BindTools(tools []*schema.ToolInfo) error {
	return nil
}

func (c *chatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	return c.msgs[0], nil
}

func (c *chatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	// The pipe is buffered to len(c.msgs), so the goroutine never blocks on Send.
	sr, sw := schema.Pipe[*schema.Message](len(c.msgs))
	go func() {
		for _, msg := range c.msgs {
			sw.Send(msg, nil)
		}
		sw.Close()
	}()
	return sr, nil
}

// TestValidate covers type validation of edges, parallels and branches. It
// checks both mismatches rejected up front by AddEdge/Compile, and any-typed
// outputs whose concrete type is only checked at run time.
func TestValidate(t *testing.T) {
	// test unmatched nodes
	g := NewGraph[string, string]()
	err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil }))
	assert.NoError(t, err)

	err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input int) (output string, err error) { return "", nil }))
	assert.NoError(t, err)

	err = g.AddEdge("1", "2")
	assert.ErrorContains(t, err, "graph edge[1]-[2]: start node's output type[string] and end node's input type[int] mismatch")
// test unmatched passthrough node
	g = NewGraph[string, string]()
	err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil }))
	assert.NoError(t, err)

	err = g.AddPassthroughNode("2")
	assert.NoError(t, err)

	err = g.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input int) (output string, err error) { return "", nil }))
	assert.NoError(t, err)

	err = g.AddEdge("1", "2")
	assert.NoError(t, err)

	// The passthrough node takes on node "1"'s string output type, so wiring it
	// to an int-input node is rejected.
	err = g.AddEdge("2", "3")
	assert.ErrorContains(t, err, "graph edge[2]-[3]: start node's output type[string] and end node's input type[int] mismatch")

	// test may matched passthrough
	g2 := NewGraph[any, string]()
	err = g2.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input any) (output any, err error) { return input, nil }))
	assert.NoError(t, err)

	err = g2.AddPassthroughNode("2")
	assert.NoError(t, err)

	err = g2.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input int) (output string, err error) { return strconv.Itoa(input), nil }))
	assert.NoError(t, err)

	err = g2.AddEdge(START, "1")
	assert.NoError(t, err)

	err = g2.AddEdge("2", "3")
	assert.NoError(t, err)

	err = g2.AddEdge("1", "2")
	assert.NoError(t, err)

	err = g2.AddEdge("3", END)
	assert.NoError(t, err)

	ru, err := g2.Compile(context.Background())
	assert.NoError(t, err)
	// success
	result, err := ru.Invoke(context.Background(), 1)
	assert.NoError(t, err)
	// NOTE(review): testify's Equal takes (expected, actual); the arguments are
	// swapped here, which only affects the failure message.
	assert.Equal(t, result, "1")
	// fail
	_, err = ru.Invoke(context.Background(), "1")
	assert.ErrorContains(t, err, "runtime type check")

	// test unmatched graph type
	g = NewGraph[string, string]()
	err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input int) (output string, err error) { return "", nil }))
	assert.NoError(t, err)

	err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output int, err error) { return 0, nil }))
	assert.NoError(t, err)

	err = g.AddEdge("1", "2")
	assert.NoError(t, err)

	err = g.AddEdge(START, "1")
	assert.ErrorContains(t, err, "graph edge[start]-[1]: start node's output type[string] and end node's input type[int] mismatch")

	// sub graph implement
	type A interface {
		A()
	}
	type B interface {
		B()
	}
	type AB interface {
		A
		B
	}
	lA := InvokableLambda(func(ctx context.Context, input A) (output string, err error) { return "", nil })
	lB := InvokableLambda(func(ctx context.Context, input B) (output string, err error) { return "", nil })
	lAB := InvokableLambda(func(ctx context.Context, input string) (output AB, err error) { return nil, nil })
	p := NewParallel().AddLambda("1", lA).AddLambda("2", lB)
	c := NewChain[string, map[string]any]().AppendLambda(lAB).AppendParallel(p)
	_, err = c.Compile(context.Background())
	assert.NoError(t, err)

	// error usage
	p = NewParallel().AddLambda("1", lA).AddLambda("2", lAB)
	c = NewChain[string, map[string]any]().AppendParallel(p)
	_, err = c.Compile(context.Background())
	assert.ErrorContains(t, err, "add parallel edge failed, from=start, to=node_0_parallel_0, err: graph edge[start]-[node_0_parallel_0]: start node's output type[string] and end node's input type[compose.A] mismatch")

	// test graph output type check
	gg := NewGraph[string, A]()
	err = gg.AddLambdaNode("nodeA", InvokableLambda(func(ctx context.Context, input string) (output A, err error) { return nil, nil }))
	assert.NoError(t, err)

	err = gg.AddLambdaNode("nodeA2", InvokableLambda(func(ctx context.Context, input string) (output A, err error) { return nil, nil }))
	assert.NoError(t, err)

	err = gg.AddLambdaNode("nodeB", InvokableLambda(func(ctx context.Context, input string) (output B, err error) { return nil, nil }))
	assert.NoError(t, err)

	err = gg.AddEdge("nodeA", END)
	assert.NoError(t, err)

	err = gg.AddEdge("nodeB", END)
	assert.ErrorContains(t, err, "graph edge[nodeB]-[end]: start node's output type[compose.B] and end node's input type[compose.A] mismatch")

	// The expected message below still names nodeB, i.e. after one rejected
	// edge the graph keeps returning that first error for later AddEdge calls.
	err = gg.AddEdge("nodeA2", END)
	assert.ErrorContains(t, err, "graph edge[nodeB]-[end]: start node's output type[compose.B] and end node's input type[compose.A] mismatch")

	// test any type
	anyG := NewGraph[any, string]()
	err = anyG.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node1", nil
	}))
	assert.NoError(t, err)
	err = anyG.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node2", nil
	}))
	assert.NoError(t, err)
	err = anyG.AddEdge(START, "node1")
	assert.NoError(t, err)
	err = anyG.AddEdge("node1", "node2")
	assert.NoError(t, err)
	err = anyG.AddEdge("node2", END)
	if err != nil {
		t.Fatal(err)
	}
	r, err := anyG.Compile(context.Background())
	assert.NoError(t, err)
	result, err = r.Invoke(context.Background(), "start")
	assert.NoError(t, err)
	assert.Equal(t, "startnode1node2", result)

	streamResult, err := r.Stream(context.Background(), "start")
	assert.NoError(t, err)
	result = ""
	// NOTE(review): a non-EOF error is reported but does not end the loop; if
	// Recv keeps returning it, this loop never terminates.
	for {
		chunk, err := streamResult.Recv()
		if err != nil {
			if err == io.EOF {
				break
			}
			assert.NoError(t, err)
		}
		result += chunk
	}
	assert.Equal(t, "startnode1node2", result)

	// test any type runtime error
	// node1 declares an any output but returns an int, while node2 expects a
	// string; the mismatch can only be detected at run time.
	anyG = NewGraph[any, string]()
	err = anyG.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return 123, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = anyG.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node2", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = anyG.AddEdge(START, "node1")
	if err != nil {
		t.Fatal(err)
	}
	err = anyG.AddEdge("node1", "node2")
	if err != nil {
		t.Fatal(err)
	}
	err = anyG.AddEdge("node2", END)
	if err != nil {
		t.Fatal(err)
	}
	r, err = anyG.Compile(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	_, err = r.Invoke(context.Background(), "start")
	if err == nil || !strings.Contains(err.Error(), "runtime") {
		t.Fatal("test any type runtime error fail, error is nil or error doesn't contain key word runtime")
	}
	_, err = r.Stream(context.Background(), "start")
	if err == nil || !strings.Contains(err.Error(), "runtime") {
		t.Fatal("test any type runtime error fail, error is nil or error doesn't contain key word runtime")
	}

	// test branch any type
	// success
	g = NewGraph[string, string]()
	err = g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node1", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node2", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node3", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node3", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) {
		return "node2", nil
	}, map[string]bool{"node2": true, "node3": true}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "node1")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node2", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node3", END)
	if err != nil {
		t.Fatal(err)
	}
	rr, err := g.Compile(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	ret, err := rr.Invoke(context.Background(), "start")
	if err != nil {
		t.Fatal(err)
	}
	if ret != "startnode1node2" {
		t.Fatal("test branch any type fail, result is unexpected")
	}
	streamResult, err = rr.Stream(context.Background(), "start")
	if err != nil {
		t.Fatal(err)
	}
	ret, err = concatStreamReader(streamResult)
	if err != nil {
		t.Fatal(err)
	}
	if ret != "startnode1node2" {
		t.Fatal("test branch any type fail, result is unexpected")
	}
	// fail
	g = NewGraph[string, string]()
	err = g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return 1 /*error type*/, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node2", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node3", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node3", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) {
		return "node2", nil
	}, map[string]bool{"node2": true, "node3": true}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "node1")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node2", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node3", END)
	if err != nil {
		t.Fatal(err)
	}
	rr, err = g.Compile(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	_, err = rr.Invoke(context.Background(), "start")
	if err == nil || !strings.Contains(err.Error(), "runtime") {
		t.Fatal("test branch any type fail, haven't report runtime error")
	}
	_, err = rr.Stream(context.Background(), "start")
	if err == nil || !strings.Contains(err.Error(), "runtime") {
		t.Fatal("test branch any type fail, haven't report runtime error")
	}
}

// TestValidateMultiAnyValueBranch attaches two branches to a node with an
// any-typed output.
//   - Success case: the branches select node2 and node4, and the graph output
//     contains both nodes' map entries.
//   - Fail case: the second branch's condition takes an int, so node1's string
//     value must be rejected at run time.
func TestValidateMultiAnyValueBranch(t *testing.T) {
	// success
	g := NewGraph[string, map[string]any]()
	err := g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node1", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) {
		return map[string]any{"node2": true}, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node3", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) {
		return map[string]any{"node3": true}, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node4", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) {
		return map[string]any{"node4": true}, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err =
g.AddLambdaNode("node5", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) {
		return map[string]any{"node5": true}, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	// Two independent branches on node1: the first picks node2, the second picks node4.
	err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) {
		return "node2", nil
	}, map[string]bool{"node2": true, "node3": true}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) {
		return "node4", nil
	}, map[string]bool{"node4": true, "node5": true}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "node1")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node2", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node3", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node4", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node5", END)
	if err != nil {
		t.Fatal(err)
	}
	rr, err := g.Compile(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	ret, err := rr.Invoke(context.Background(), "start")
	if err != nil {
		t.Fatal(err)
	}
	if !ret["node2"].(bool) || !ret["node4"].(bool) {
		t.Fatal("test branch any type fail, result is unexpected")
	}
	streamResult, err := rr.Stream(context.Background(), "start")
	if err != nil {
		t.Fatal(err)
	}
	ret, err = concatStreamReader(streamResult)
	if err != nil {
		t.Fatal(err)
	}
	if !ret["node2"].(bool) || !ret["node4"].(bool) {
		t.Fatal("test branch any type fail, result is unexpected")
	}

	// fail
	g = NewGraph[string, map[string]any]()
	err = g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node1", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) {
		return map[string]any{"node2": true}, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node3", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) {
		return map[string]any{"node3": true}, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node4", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) {
		return map[string]any{"node4": true}, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node5", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) {
		return map[string]any{"node5": true}, nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) {
		return "node2", nil
	}, map[string]bool{"node2": true, "node3": true}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in int /*error type*/) (endNode string, err error) {
		return "node4", nil
	}, map[string]bool{"node4": true, "node5": true}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "node1")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node2", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node3", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node4", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node5", END)
	if err != nil {
		t.Fatal(err)
	}
	rr, err = g.Compile(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	_, err = rr.Invoke(context.Background(), "start")
	if err == nil || !strings.Contains(err.Error(), "runtime") {
		t.Fatal("test multi branch any type fail, haven't report runtime error")
	}
	_, err = rr.Stream(context.Background(), "start")
	if err == nil || !strings.Contains(err.Error(), "runtime") {
		t.Fatal("test multi branch any type fail, haven't report runtime error")
	}
}

// TestAnyTypeWithKey checks WithInputKey/WithOutputKey on a graph with an any
// input. node1 reads its input from key "node1" of the graph input map, and
// node2's output is placed under key "node2" of the graph output.
func TestAnyTypeWithKey(t *testing.T) {
	g := NewGraph[any, map[string]any]()
	err := g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node1", nil
	}), WithInputKey("node1"))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) {
		return input + "node2", nil
	}), WithOutputKey("node2"))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "node1")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node1", "node2")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node2", END)
	if err != nil {
		t.Fatal(err)
	}
	r, err := g.Compile(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	result, err := r.Invoke(context.Background(), map[string]any{"node1": "start"})
	if err != nil {
		t.Fatal(err)
	}
	if result["node2"] != "startnode1node2" {
		t.Fatal("test any type with key fail, result is unexpected")
	}

	streamResult, err := r.Stream(context.Background(), map[string]any{"node1": "start"})
	if err != nil {
		t.Fatal(err)
	}
	ret, err := concatStreamReader(streamResult)
	if err != nil {
		t.Fatal(err)
	}
	if ret["node2"] != "startnode1node2" {
		t.Fatal("test any type with key fail, result is unexpected")
	}
}

// TestInputKey fans START out to three prompt nodes. Each node reads its
// variables from its own input key and writes under its own output key. The
// test checks that Invoke, and Transform with the input split across three
// chunks, both produce the three rendered messages.
func TestInputKey(t *testing.T) {
	g := NewGraph[map[string]any, map[string]any]()
	err := g.AddChatTemplateNode("1", prompt.FromMessages(schema.FString, schema.UserMessage("{var1}")), WithOutputKey("1"), WithInputKey("1"))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddChatTemplateNode("2", prompt.FromMessages(schema.FString, schema.UserMessage("{var2}")), WithOutputKey("2"), WithInputKey("2"))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddChatTemplateNode("3", prompt.FromMessages(schema.FString, schema.UserMessage("{var3}")), WithOutputKey("3"), WithInputKey("3"))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "1")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "2")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "3")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("1", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("2", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("3", END)
	if err != nil {
		t.Fatal(err)
	}
	r, err := g.Compile(context.Background(), WithMaxRunSteps(100))
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	result, err := r.Invoke(ctx, map[string]any{
		"1": map[string]any{"var1": "a"},
		"2": map[string]any{"var2": "b"},
		"3": map[string]any{"var3": "c"},
	})
	if err != nil {
		t.Fatal(err)
	}
	if result["1"].([]*schema.Message)[0].Content != "a" ||
		result["2"].([]*schema.Message)[0].Content != "b" ||
		result["3"].([]*schema.Message)[0].Content != "c" {
		t.Fatal("invoke different")
	}

	sr, sw := schema.Pipe[map[string]any](10)
	sw.Send(map[string]any{"1": map[string]any{"var1": "a"}}, nil)
	sw.Send(map[string]any{"2": map[string]any{"var2": "b"}}, nil)
	sw.Send(map[string]any{"3": map[string]any{"var3": "c"}}, nil)
	sw.Close()

	streamResult, err := r.Transform(ctx, sr)
	if err != nil {
		t.Fatal(err)
	}
	defer streamResult.Close()

	// Merge the streamed chunks into one map keyed by output key.
	result = make(map[string]any)
	for {
		chunk, err := streamResult.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatal(err)
		}
		for k, v := range chunk {
			result[k] = v
		}
	}
	if result["1"].([]*schema.Message)[0].Content != "a" ||
		result["2"].([]*schema.Message)[0].Content != "b" ||
		result["3"].([]*schema.Message)[0].Content != "c" {
		t.Fatal("transform different")
	}
}

// TestTransferTask compares transferTask's result for a fixed five-layer input
// and invertedEdges map against a hand-written expected layering.
func TestTransferTask(t *testing.T) {
	in := [][]string{
		{
			"1", "2",
		},
		{
			"3", "4", "5", "6",
		},
		{
			"5", "6", "7",
		},
		{
			"7", "8",
		},
		{
			"8",
		},
	}
	invertedEdges := map[string][]string{
		"1": {"3", "4"},
		"2": {"5", "6"},
		"3": {"5"},
		"4": {"6"},
		"5": {"7"},
		"7": {"8"},
	}
	in = transferTask(in, invertedEdges)

	if !reflect.DeepEqual(
		[][]string{
			{
				"1",
			},
			{
				"3", "2",
			},
			{
				"5",
			},
			{
				"7", "4",
			},
			{
				"8", "6",
			},
		}, in) {
		t.Fatal("not equal")
	}
}

// TestPregelEnd wires node1 both directly to END and through node2 to END. It
// checks that the graph output is node1's output ("node1").
// NOTE(review): presumably this pins that the run finishes as soon as END
// receives its first input, before node2's result arrives — confirm.
func TestPregelEnd(t *testing.T) {
	g := NewGraph[string, string]()
	err := g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return "node1", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return "node2", nil
	}))
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge(START, "node1")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node1", END)
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node1", "node2")
	if err != nil {
		t.Fatal(err)
	}
	err = g.AddEdge("node2", END)
	if err != nil {
		t.Fatal(err)
	}
	runner, err := g.Compile(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	out, err := runner.Invoke(context.Background(), "")
	if err != nil {
		t.Fatal(err)
	}
	if out != "node1" {
		t.Fatal("graph output is unexpected")
	}
}

// cb is a graph compile callback that stores the GraphInfo passed to OnFinish,
// so a test can inspect it after Compile returns.
type cb struct {
	gInfo *GraphInfo
}

// OnFinish records info for later inspection.
func (c *cb) OnFinish(ctx context.Context, info *GraphInfo) {
	c.gInfo = info
}

// TestGraphCompileCallback builds a graph with a branch, passthrough nodes and a
// two-level nested sub-graph. It compiles the graph with a cb compile callback
// and compares the reported GraphInfo against a hand-built expected value.
func TestGraphCompileCallback(t *testing.T) {
	t.Run("graph compile callback", func(t *testing.T) {
		type s struct{}

		g := NewGraph[map[string]any, map[string]any](WithGenLocalState(func(ctx context.Context) *s { return &s{} }))

		lambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
			return "node1", nil
		})
		lambdaOpts := []GraphAddNodeOpt{WithNodeName("lambda_1"), WithInputKey("input_key")}
		err := g.AddLambdaNode("node1", lambda, lambdaOpts...)
		assert.NoError(t, err)

		err = g.AddPassthroughNode("pass1")
		assert.NoError(t, err)
		err = g.AddPassthroughNode("pass2")
		assert.NoError(t, err)

		condition := func(ctx context.Context, input string) (string, error) {
			return input, nil
		}

		branch := NewGraphBranch(condition, map[string]bool{"pass1": true, "pass2": true})
		err = g.AddBranch("node1", branch)
		assert.NoError(t, err)

		err = g.AddEdge(START, "node1")
		assert.NoError(t, err)

		lambda2 := InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
			return "node2", nil
		})
		lambdaOpts2 := []GraphAddNodeOpt{WithNodeName("lambda_2")}
		subSubGraph := NewGraph[string, string]()
		err = subSubGraph.AddLambdaNode("sub1", lambda2, lambdaOpts2...)
assert.NoError(t, err) err = subSubGraph.AddEdge(START, "sub1") assert.NoError(t, err) err = subSubGraph.AddEdge("sub1", END) assert.NoError(t, err) subGraph := NewGraph[string, string]() var ssGraphCompileOpts []GraphCompileOption ssGraphOpts := []GraphAddNodeOpt{WithGraphCompileOptions(ssGraphCompileOpts...)} err = subGraph.AddGraphNode("sub_sub_1", subSubGraph, ssGraphOpts...) assert.NoError(t, err) err = subGraph.AddEdge(START, "sub_sub_1") assert.NoError(t, err) err = subGraph.AddEdge("sub_sub_1", END) assert.NoError(t, err) subGraphCompileOpts := []GraphCompileOption{WithMaxRunSteps(2), WithGraphName("sub_graph")} subGraphOpts := []GraphAddNodeOpt{WithGraphCompileOptions(subGraphCompileOpts...)} err = g.AddGraphNode("sub_graph", subGraph, subGraphOpts...) assert.NoError(t, err) err = g.AddEdge("pass1", "sub_graph") assert.NoError(t, err) err = g.AddEdge("pass2", "sub_graph") assert.NoError(t, err) lambda3 := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "node3", nil }) lambdaOpts3 := []GraphAddNodeOpt{WithNodeName("lambda_3"), WithOutputKey("lambda_3")} err = g.AddLambdaNode("node3", lambda3, lambdaOpts3...) assert.NoError(t, err) lambda4 := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "node4", nil }) lambdaOpts4 := []GraphAddNodeOpt{WithNodeName("lambda_4"), WithOutputKey("lambda_4")} err = g.AddLambdaNode("node4", lambda4, lambdaOpts4...) assert.NoError(t, err) err = g.AddEdge("sub_graph", "node3") assert.NoError(t, err) err = g.AddEdge("sub_graph", "node4") assert.NoError(t, err) err = g.AddEdge("node3", END) assert.NoError(t, err) err = g.AddEdge("node4", END) assert.NoError(t, err) c := &cb{} opt := []GraphCompileOption{WithGraphCompileCallbacks(c), WithGraphName("top_level")} _, err = g.Compile(context.Background(), opt...) 
assert.NoError(t, err) expected := &GraphInfo{ CompileOptions: opt, Nodes: map[string]GraphNodeInfo{ "node1": { Component: ComponentOfLambda, Instance: lambda, GraphAddNodeOpts: lambdaOpts, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "lambda_1", InputKey: "input_key", }, "pass1": { Component: ComponentOfPassthrough, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "", }, "pass2": { Component: ComponentOfPassthrough, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "", }, "sub_graph": { Component: ComponentOfGraph, Instance: subGraph, GraphAddNodeOpts: subGraphOpts, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "", GraphInfo: &GraphInfo{ CompileOptions: subGraphCompileOpts, Nodes: map[string]GraphNodeInfo{ "sub_sub_1": { Component: ComponentOfGraph, Instance: subSubGraph, GraphAddNodeOpts: ssGraphOpts, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "", GraphInfo: &GraphInfo{ CompileOptions: ssGraphCompileOpts, Nodes: map[string]GraphNodeInfo{ "sub1": { Component: ComponentOfLambda, Instance: lambda2, GraphAddNodeOpts: lambdaOpts2, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "lambda_2", }, }, Edges: map[string][]string{ START: {"sub1"}, "sub1": {END}, }, DataEdges: map[string][]string{ START: {"sub1"}, "sub1": {END}, }, Branches: map[string][]GraphBranch{}, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), }, }, }, Edges: map[string][]string{ START: {"sub_sub_1"}, "sub_sub_1": {END}, }, DataEdges: map[string][]string{ START: {"sub_sub_1"}, "sub_sub_1": {END}, }, Branches: map[string][]GraphBranch{}, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "sub_graph", }, }, "node3": { Component: ComponentOfLambda, Instance: lambda3, GraphAddNodeOpts: lambdaOpts3, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "lambda_3", OutputKey: "lambda_3", }, "node4": { Component: ComponentOfLambda, 
Instance: lambda4, GraphAddNodeOpts: lambdaOpts4, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), Name: "lambda_4", OutputKey: "lambda_4", }, }, Edges: map[string][]string{ START: {"node1"}, "pass1": {"sub_graph"}, "pass2": {"sub_graph"}, "sub_graph": {"node3", "node4"}, "node3": {END}, "node4": {END}, }, DataEdges: map[string][]string{ START: {"node1"}, "pass1": {"sub_graph"}, "pass2": {"sub_graph"}, "sub_graph": {"node3", "node4"}, "node3": {END}, "node4": {END}, }, Branches: map[string][]GraphBranch{ "node1": {*branch}, }, InputType: reflect.TypeOf(map[string]any{}), OutputType: reflect.TypeOf(map[string]any{}), Name: "top_level", } stateFn := c.gInfo.GenStateFn assert.NotNil(t, stateFn) assert.Equal(t, &s{}, stateFn(context.Background())) assert.Equal(t, 1, len(c.gInfo.NewGraphOptions)) c.gInfo.NewGraphOptions = nil c.gInfo.GenStateFn = nil actualCompileOptions := newGraphCompileOptions(c.gInfo.CompileOptions...) expectedCompileOptions := newGraphCompileOptions(expected.CompileOptions...) 
assert.Equal(t, len(expectedCompileOptions.callbacks), len(actualCompileOptions.callbacks)) assert.Same(t, expectedCompileOptions.callbacks[0], actualCompileOptions.callbacks[0]) actualCompileOptions.callbacks = nil actualCompileOptions.origOpts = nil expectedCompileOptions.callbacks = nil expectedCompileOptions.origOpts = nil assert.Equal(t, expectedCompileOptions, actualCompileOptions) c.gInfo.CompileOptions = nil expected.CompileOptions = nil assert.Equal(t, expected.Branches["node1"][0].endNodes, c.gInfo.Branches["node1"][0].endNodes) assert.Equal(t, expected.Branches["node1"][0].inputType, c.gInfo.Branches["node1"][0].inputType) expected.Branches["node1"] = []GraphBranch{} c.gInfo.Branches["node1"] = []GraphBranch{} assert.Equal(t, expected, c.gInfo) }) } func TestCheckAddEdge(t *testing.T) { g := NewGraph[string, string]() err := g.AddPassthroughNode("1") if err != nil { t.Fatal(err) } err = g.AddPassthroughNode("2") if err != nil { t.Fatal(err) } err = g.AddEdge("1", "2") if err != nil { t.Fatal(err) } err = g.AddEdge("1", "2") if err == nil { t.Fatal("add edge repeatedly haven't report error") } } func TestStartWithEnd(t *testing.T) { g := NewGraph[string, string]() err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil })) if err != nil { t.Fatal(err) } err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { return END, nil }, map[string]bool{"1": true, END: true})) if err != nil { t.Fatal(err) } r, err := g.Compile(context.Background()) if err != nil { t.Fatal(err) } sr, sw := schema.Pipe[string](1) sw.Send("test", nil) sw.Close() result, err := r.Transform(context.Background(), sr) if err != nil { t.Fatal(err) } for { chunk, err := result.Recv() if err == io.EOF { break } if err != nil { t.Fatal(err) } if chunk != "test" { t.Fatal("result is out of expect") } } } func TestToString(t *testing.T) { ps := runTypePregel.String() 
assert.Equal(t, "Pregel", ps) ds := runTypeDAG assert.Equal(t, "DAG", ds.String()) } func TestInputKeyError(t *testing.T) { g := NewGraph[map[string]any, string]() err := g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }), WithInputKey("node1")) if err != nil { t.Fatal(err) } err = g.AddEdge(START, "node1") if err != nil { t.Fatal(err) } err = g.AddEdge("node1", END) if err != nil { t.Fatal(err) } ctx := context.Background() r, err := g.Compile(ctx) if err != nil { t.Fatal(err) } // invoke _, err = r.Invoke(ctx, map[string]any{"unknown": "123"}) if err == nil || !strings.Contains(err.Error(), "cannot find input key: node1") { t.Fatal("cannot report input key error correctly") } // transform sr, sw := schema.Pipe[map[string]any](1) sw.Send(map[string]any{"unknown": "123"}, nil) sw.Close() _, err = r.Transform(ctx, sr) if err == nil || !strings.Contains(err.Error(), "stream reader is empty, concat fail") { t.Fatal("cannot report input key error correctly") } } func TestContextCancel(t *testing.T) { ctx := context.Background() g := NewGraph[string, string]() err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil })) if err != nil { t.Fatal(err) } err = g.AddEdge(START, "1") if err != nil { t.Fatal(err) } err = g.AddEdge("1", END) if err != nil { t.Fatal(err) } r, err := g.Compile(ctx) if err != nil { t.Fatal(err) } ctx, cancel := context.WithCancel(ctx) cancel() _, err = r.Invoke(ctx, "test") if !strings.Contains(err.Error(), "context has been canceled") { t.Fatal("graph have not returned canceled error") } } func TestDAGStart(t *testing.T) { g := NewGraph[map[string]any, map[string]any]() err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input map[string]any) (output map[string]any, err error) { return map[string]any{"1": "1"}, nil })) assert.NoError(t, err) err = g.AddLambdaNode("2", 
InvokableLambda(func(ctx context.Context, input map[string]any) (output map[string]any, err error) { return input, nil })) assert.NoError(t, err) err = g.AddEdge(START, "1") assert.NoError(t, err) err = g.AddEdge("1", "2") assert.NoError(t, err) err = g.AddEdge(START, "2") assert.NoError(t, err) err = g.AddEdge("2", END) assert.NoError(t, err) r, err := g.Compile(context.Background(), WithNodeTriggerMode(AllPredecessor)) assert.NoError(t, err) result, err := r.Invoke(context.Background(), map[string]any{"start": "start"}) assert.NoError(t, err) assert.Equal(t, map[string]any{"start": "start", "1": "1"}, result) } func concatLambda(s string) *Lambda { return InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + s, nil }) } func mapLambda(k, v string) *Lambda { return InvokableLambda(func(ctx context.Context, input map[string]string) (output map[string]string, err error) { return map[string]string{ k: v, }, nil }) } func TestBaseDAGBranch(t *testing.T) { g := NewGraph[string, string]() err := g.AddLambdaNode("1", concatLambda("1")) assert.NoError(t, err) err = g.AddLambdaNode("2", concatLambda("2")) assert.NoError(t, err) err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { if len(in) > 3 { return "2", nil } return "1", nil }, map[string]bool{"1": true, "2": true})) assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) err = g.AddEdge("2", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor)) assert.NoError(t, err) result, err := r.Invoke(ctx, "hi") assert.NoError(t, err) assert.Equal(t, "hi1", result) } func TestMultiDAGBranch(t *testing.T) { g := NewGraph[map[string]string, map[string]string]() err := g.AddLambdaNode("1", mapLambda("1", "1")) assert.NoError(t, err) err = g.AddLambdaNode("2", mapLambda("2", "2")) assert.NoError(t, err) err = g.AddLambdaNode("3", mapLambda("3", "3")) 
assert.NoError(t, err) err = g.AddLambdaNode("4", mapLambda("4", "4")) assert.NoError(t, err) err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in map[string]string) (endNode string, err error) { if len(in["input"]) > 3 { return "2", nil } return "1", nil }, map[string]bool{"1": true, "2": true})) err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in map[string]string) (endNode string, err error) { if len(in["input"]) > 3 { return "4", nil } return "3", nil }, map[string]bool{"3": true, "4": true})) assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) err = g.AddEdge("2", END) assert.NoError(t, err) err = g.AddEdge("3", END) assert.NoError(t, err) err = g.AddEdge("4", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor)) assert.NoError(t, err) result, err := r.Invoke(ctx, map[string]string{"input": "hi"}) assert.NoError(t, err) assert.Equal(t, map[string]string{ "1": "1", "3": "3", }, result) } func TestCrossDAGBranch(t *testing.T) { g := NewGraph[map[string]string, map[string]string]() err := g.AddLambdaNode("1", mapLambda("1", "1")) assert.NoError(t, err) err = g.AddLambdaNode("2", mapLambda("2", "2")) assert.NoError(t, err) err = g.AddLambdaNode("3", mapLambda("3", "3")) assert.NoError(t, err) err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in map[string]string) (endNode string, err error) { if len(in["input"]) > 3 { return "2", nil } return "1", nil }, map[string]bool{"1": true, "2": true})) err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in map[string]string) (endNode string, err error) { if len(in["input"]) > 3 { return "3", nil } return "2", nil }, map[string]bool{"2": true, "3": true})) assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) err = g.AddEdge("2", END) assert.NoError(t, err) err = g.AddEdge("3", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, 
WithNodeTriggerMode(AllPredecessor)) assert.NoError(t, err) result, err := r.Invoke(ctx, map[string]string{"input": "hi"}) assert.NoError(t, err) assert.Equal(t, map[string]string{ "1": "1", "2": "2", }, result) } func TestNestedDAGBranch(t *testing.T) { g := NewGraph[string, string]() err := g.AddLambdaNode("1", concatLambda("1")) assert.NoError(t, err) err = g.AddLambdaNode("2", concatLambda("2")) assert.NoError(t, err) err = g.AddLambdaNode("3", concatLambda("3")) assert.NoError(t, err) err = g.AddLambdaNode("4", concatLambda("4")) assert.NoError(t, err) err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { if len(in) > 3 { return "2", nil } return "1", nil }, map[string]bool{"1": true, "2": true})) err = g.AddBranch("2", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { if len(in) > 10 { return "4", nil } return "3", nil }, map[string]bool{"3": true, "4": true})) assert.NoError(t, err) err = g.AddEdge("1", END) assert.NoError(t, err) err = g.AddEdge("3", END) assert.NoError(t, err) err = g.AddEdge("4", END) assert.NoError(t, err) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor)) assert.NoError(t, err) result, err := r.Invoke(ctx, "hello") assert.NoError(t, err) assert.Equal(t, "hello23", result) result, err = r.Invoke(ctx, "hi") assert.NoError(t, err) assert.Equal(t, "hi1", result) result, err = r.Invoke(ctx, "hellohello") assert.NoError(t, err) assert.Equal(t, "hellohello24", result) } func TestHandlerTypeValidate(t *testing.T) { g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state string) { return "" })) // passthrough pre fail err := g.AddPassthroughNode("1", WithStatePreHandler(func(ctx context.Context, in string, state string) (string, error) { return "", nil })) assert.ErrorContains(t, err, "passthrough node[1]'s pre handler type isn't any") g.buildError = nil // passthrough pre fail with input key err = 
g.AddPassthroughNode("1", WithStatePreHandler(func(ctx context.Context, in string, state string) (string, error) { return "", nil }), WithInputKey("input")) assert.ErrorContains(t, err, "node[1]'s pre handler type[string] is different from its input type[map[string]interface {}]") g.buildError = nil // passthrough post fail err = g.AddPassthroughNode("1", WithStatePostHandler(func(ctx context.Context, in string, state string) (string, error) { return "", nil })) assert.ErrorContains(t, err, "passthrough node[1]'s post handler type isn't any") g.buildError = nil // passthrough post fail with input key err = g.AddPassthroughNode("1", WithStatePostHandler(func(ctx context.Context, in string, state string) (string, error) { return "", nil }), WithInputKey("input")) assert.ErrorContains(t, err, "passthrough node[1]'s post handler type isn't any") g.buildError = nil // passthrough pre success err = g.AddPassthroughNode("1", WithStatePreHandler(func(ctx context.Context, in any, state string) (any, error) { return "", nil })) assert.NoError(t, err) // passthrough pre success with input key err = g.AddPassthroughNode("2", WithStatePreHandler(func(ctx context.Context, in map[string]any, state string) (map[string]any, error) { return nil, nil }), WithInputKey("input")) assert.NoError(t, err) // passthrough post success err = g.AddPassthroughNode("3", WithStatePostHandler(func(ctx context.Context, in any, state string) (any, error) { return "", nil })) assert.NoError(t, err) // passthrough post success with output key err = g.AddPassthroughNode("4", WithStatePostHandler(func(ctx context.Context, in map[string]any, state string) (map[string]any, error) { return nil, nil }), WithOutputKey("output")) assert.NoError(t, err) // common node pre fail err = g.AddLambdaNode("5", InvokableLambda(func(ctx context.Context, input int) (output int, err error) { return 0, nil }), WithStatePreHandler(func(ctx context.Context, in string, state string) (string, error) { return "", nil })) 
assert.ErrorContains(t, err, "node[5]'s pre handler type[string] is different from its input type[int]") g.buildError = nil // common node post fail err = g.AddLambdaNode("5", InvokableLambda(func(ctx context.Context, input int) (output int, err error) { return 0, nil }), WithStatePostHandler(func(ctx context.Context, in string, state string) (string, error) { return "", nil })) assert.ErrorContains(t, err, "node[5]'s post handler type[string] is different from its output type[int]") g.buildError = nil // common node pre success err = g.AddLambdaNode("5", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil }), WithStatePreHandler(func(ctx context.Context, in string, state string) (string, error) { return "", nil })) assert.NoError(t, err) // common node post success err = g.AddLambdaNode("6", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil }), WithStatePostHandler(func(ctx context.Context, in string, state string) (string, error) { return "", nil })) assert.NoError(t, err) // pre state fail err = g.AddLambdaNode("7", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil }), WithStatePreHandler(func(ctx context.Context, in string, state int) (string, error) { return "", nil })) assert.ErrorContains(t, err, "node[7]'s pre handler state type[int] is different from graph[string]") g.buildError = nil // post state fail err = g.AddLambdaNode("7", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil }), WithStatePostHandler(func(ctx context.Context, in string, state int) (string, error) { return "", nil })) assert.ErrorContains(t, err, "node[7]'s post handler state type[int] is different from graph[string]") g.buildError = nil // common pre success with input key err = g.AddLambdaNode("7", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return 
"", nil }), WithStatePreHandler(func(ctx context.Context, in map[string]any, state string) (map[string]any, error) { return nil, nil }), WithInputKey("input")) assert.NoError(t, err) // common post success with output key err = g.AddLambdaNode("8", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil }), WithStatePostHandler(func(ctx context.Context, in map[string]any, state string) (map[string]any, error) { return nil, nil }), WithOutputKey("output")) assert.NoError(t, err) } func TestSetFanInMergeConfig_RealStreamNode(t *testing.T) { for _, triggerMode := range []NodeTriggerMode{AnyPredecessor, AllPredecessor} { t.Run(string(triggerMode), func(t *testing.T) { g := NewGraph[int, map[string]any]() // Add two stream nodes that output streams of int slices err := g.AddLambdaNode("s1", StreamableLambda(func(ctx context.Context, input int) (*schema.StreamReader[int], error) { sr, sw := schema.Pipe[int](2) sw.Send(input+1, nil) sw.Send(input+2, nil) sw.Close() return sr, nil }), WithOutputKey("s1")) assert.NoError(t, err) err = g.AddLambdaNode("s2", StreamableLambda(func(ctx context.Context, input int) (*schema.StreamReader[int], error) { sr, sw := schema.Pipe[int](2) sw.Send(input+10, nil) sw.Send(input+20, nil) sw.Close() return sr, nil }), WithOutputKey("s2")) assert.NoError(t, err) // Connect edges: START -> s1, START -> s2, s1 -> END, s2 -> END err = g.AddEdge(START, "s1") assert.NoError(t, err) err = g.AddEdge(START, "s2") assert.NoError(t, err) err = g.AddEdge("s1", END) assert.NoError(t, err) err = g.AddEdge("s2", END) assert.NoError(t, err) r, err := g.Compile(context.Background(), WithNodeTriggerMode(triggerMode), WithFanInMergeConfig(map[string]FanInMergeConfig{END: {StreamMergeWithSourceEOF: true}})) assert.NoError(t, err) // Run the graph in stream mode and check for SourceEOF events sr, err := r.Stream(context.Background(), 1) assert.NoError(t, err) merged := make(map[string]map[int]bool) var sourceEOFCount 
int sourceNames := make(map[string]bool) for { m, e := sr.Recv() if e != nil { if name, ok := schema.GetSourceName(e); ok { sourceEOFCount++ sourceNames[name] = true continue } if e == io.EOF { break } assert.NoError(t, e) } for k, v := range m { if merged[k] == nil { merged[k] = make(map[int]bool) } merged[k][v.(int)] = true } } // The merged map should contain both results assert.Equal(t, map[string]map[int]bool{"s1": {2: true, 3: true}, "s2": {11: true, 21: true}}, merged) assert.Equal(t, 2, sourceEOFCount, "should receive SourceEOF for each input stream when StreamMergeWithSourceEOF is true") assert.True(t, sourceNames["s1"], "should receive SourceEOF from s1") assert.True(t, sourceNames["s2"], "should receive SourceEOF from s2") }) } } func TestFindLoops(t *testing.T) { tests := []struct { name string startNodes []string chanCalls map[string]*chanCall expected [][]string }{ { name: "Graph without cycles", startNodes: []string{"A"}, chanCalls: map[string]*chanCall{ "A": { controls: []string{"B", "C"}, }, "B": { controls: []string{"D"}, }, "C": { controls: []string{"E"}, }, "D": { controls: []string{}, }, "E": { controls: []string{}, }, }, expected: [][]string{}, }, { name: "Graph with self-loop", startNodes: []string{"A"}, chanCalls: map[string]*chanCall{ "A": { controls: []string{"A", "B"}, }, "B": { controls: []string{}, }, }, expected: [][]string{{"A", "A"}}, }, { name: "Graph with simple cycle", startNodes: []string{"A", "B", "C"}, chanCalls: map[string]*chanCall{ "A": { controls: []string{"B"}, }, "B": { controls: []string{"C"}, }, "C": { controls: []string{"A"}, }, }, expected: [][]string{{"A", "B", "C", "A"}}, }, { name: "Graph with multiple cycles", startNodes: []string{"A", "B", "C", "D", "E", "F"}, chanCalls: map[string]*chanCall{ "A": { controls: []string{"B", "D"}, }, "B": { controls: []string{"C"}, }, "C": { controls: []string{"B"}, }, "D": { controls: []string{"E"}, }, "E": { controls: []string{"F"}, }, "F": { controls: []string{"D"}, }, }, 
expected: [][]string{{"B", "C", "B"}, {"D", "E", "F", "D"}}, }, { name: "Graph with branch cycle", startNodes: []string{"A", "C"}, chanCalls: map[string]*chanCall{ "A": { controls: []string{"B"}, writeToBranches: []*GraphBranch{ { endNodes: map[string]bool{ "C": true, }, }, }, }, "B": { controls: []string{}, }, "C": { controls: []string{"A"}, }, }, expected: [][]string{{"A", "C", "A"}}, }, { name: "Empty graph", startNodes: []string{}, chanCalls: map[string]*chanCall{}, expected: [][]string{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { loops := findLoops(tt.startNodes, tt.chanCalls) assert.Equal(t, len(tt.expected), len(loops)) if len(tt.expected) > 0 { normalizedExpected := normalizeLoops(tt.expected) normalizedActual := normalizeLoops(loops) assert.Equal(t, normalizedExpected, normalizedActual) } }) } } func normalizeLoops(loops [][]string) []string { result := make([]string, 0, len(loops)) for _, loop := range loops { if len(loop) == 0 { continue } normalizedLoop := make([]string, len(loop)) copy(normalizedLoop, loop) if normalizedLoop[0] != normalizedLoop[len(normalizedLoop)-1] { normalizedLoop = append(normalizedLoop, normalizedLoop[0]) } minIdx := 0 for i := 1; i < len(normalizedLoop)-1; i++ { if normalizedLoop[i] < normalizedLoop[minIdx] { minIdx = i } } canonicalLoop := "" for i := 0; i < len(normalizedLoop)-1; i++ { idx := (minIdx + i) % (len(normalizedLoop) - 1) canonicalLoop += normalizedLoop[idx] + "," } canonicalLoop += normalizedLoop[minIdx] result = append(result, canonicalLoop) } sort.Strings(result) return result } func TestPrintTasks(t *testing.T) { var ts []*task assert.Equal(t, "[]", printTask(ts)) ts = []*task{{nodeKey: "1"}} assert.Equal(t, "[1]", printTask(ts)) ts = []*task{{nodeKey: "1"}, {nodeKey: "2"}, {nodeKey: "3"}} assert.Equal(t, "[1, 2, 3]", printTask(ts)) } func TestSkipBranch(t *testing.T) { g := NewGraph[string, string]() _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) 
(output string, err error) { return input, nil })) _ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil })) _ = g.AddEdge(START, "1") _ = g.AddBranch("1", NewGraphMultiBranch(func(ctx context.Context, in string) (endNode map[string]bool, err error) { return map[string]bool{}, nil }, map[string]bool{"2": true})) _ = g.AddEdge("2", END) ctx := context.Background() r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor)) assert.NoError(t, err) _, err = r.Invoke(ctx, "input") assert.ErrorContains(t, err, "[GraphRunError] no tasks to execute, last completed nodes: [1]") g = NewGraph[string, string]() _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil })) _ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil })) _ = g.AddEdge(START, "1") _ = g.AddBranch("1", NewGraphMultiBranch(func(ctx context.Context, in string) (endNode map[string]bool, err error) { return map[string]bool{}, nil }, map[string]bool{"2": true})) _ = g.AddEdge("2", END) _ = g.AddEdge(START, "2") r, err = g.Compile(ctx, WithNodeTriggerMode(AllPredecessor)) assert.NoError(t, err) result, err := r.Invoke(ctx, "input") assert.NoError(t, err) assert.Equal(t, "input", result) } func TestGetStateInGraphCallback(t *testing.T) { g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (s *state) { return &state{} })) assert.NoError(t, g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }))) assert.NoError(t, g.AddEdge(START, "1")) assert.NoError(t, g.AddEdge("1", END)) ctx := context.Background() r, err := g.Compile(ctx) assert.NoError(t, err) _, err = r.Invoke(ctx, "input", WithCallbacks(&testGraphStateCallbackHandler{t: t})) assert.NoError(t, err) } type state struct { A string } type 
testGraphStateCallbackHandler struct { t *testing.T } func (t *testGraphStateCallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { assert.NoError(t.t, ProcessState[*state](ctx, func(ctx context.Context, s *state) error { s.A = "test" return nil })) return ctx } func (t *testGraphStateCallbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { return ctx } func (t *testGraphStateCallbackHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { return ctx } func (t *testGraphStateCallbackHandler) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { return ctx } func (t *testGraphStateCallbackHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { return ctx } func TestUniqueSlice(t *testing.T) { assert.Equal(t, []string{"a", "b", "c"}, uniqueSlice([]string{"a", "b", "a", "c", "b"})) assert.Equal(t, []string{}, uniqueSlice([]string{})) } ================================================ FILE: compose/interrupt.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "errors" "fmt" "github.com/google/uuid" "github.com/cloudwego/eino/internal/core" "github.com/cloudwego/eino/schema" ) // WithInterruptBeforeNodes instructs to interrupt before the given nodes. func WithInterruptBeforeNodes(nodes []string) GraphCompileOption { return func(options *graphCompileOptions) { options.interruptBeforeNodes = nodes } } // WithInterruptAfterNodes instructs to interrupt after the given nodes. func WithInterruptAfterNodes(nodes []string) GraphCompileOption { return func(options *graphCompileOptions) { options.interruptAfterNodes = nodes } } // Deprecated: prefer Interrupt/StatefulInterrupt and CompositeInterrupt. // If you need to pass the legacy error into CompositeInterrupt, wrap it using WrapInterruptAndRerunIfNeeded first. var InterruptAndRerun = deprecatedInterruptAndRerun var deprecatedInterruptAndRerun = errors.New("interrupt and rerun") // NewInterruptAndRerunErr creates a legacy interrupt-and-rerun error. // Deprecated: prefer Interrupt(ctx, info) or StatefulInterrupt(ctx, info, state). // If passing into CompositeInterrupt, wrap using WrapInterruptAndRerunIfNeeded first. func NewInterruptAndRerunErr(extra any) error { return deprecatedInterruptAndRerunErr(extra) } func deprecatedInterruptAndRerunErr(extra any) error { return &core.InterruptSignal{InterruptInfo: core.InterruptInfo{ Info: extra, IsRootCause: true, }} } type wrappedInterruptAndRerun struct { ps Address inner error } func (w *wrappedInterruptAndRerun) Error() string { return fmt.Sprintf("interrupt and rerun at address %s: %s", w.ps.String(), w.inner.Error()) } func (w *wrappedInterruptAndRerun) Unwrap() error { return w.inner } // WrapInterruptAndRerunIfNeeded wraps the deprecated old interrupt errors, with the current execution address. 
// If the error is returned by either Interrupt, StatefulInterrupt or CompositeInterrupt, // it will be returned as-is without wrapping func WrapInterruptAndRerunIfNeeded(ctx context.Context, step AddressSegment, err error) error { addr := GetCurrentAddress(ctx) newAddr := append(append([]AddressSegment{}, addr...), step) if errors.Is(err, deprecatedInterruptAndRerun) { return &wrappedInterruptAndRerun{ ps: newAddr, inner: err, } } ire := &core.InterruptSignal{} if errors.As(err, &ire) { if ire.Address == nil { return &wrappedInterruptAndRerun{ ps: newAddr, inner: err, } } return ire } return fmt.Errorf("failed to wrap error as addressed InterruptAndRerun: %w", err) } // Interrupt creates a special error that signals the execution engine to interrupt // the current run at the component's specific address and save a checkpoint. // // This is the standard way for a single, non-composite component to signal a resumable interruption. // // - ctx: The context of the running component, used to retrieve the current execution address. // - info: User-facing information about the interrupt. This is not persisted but is exposed to the // calling application via the InterruptCtx to provide context (e.g., a reason for the pause). func Interrupt(ctx context.Context, info any) error { is, err := core.Interrupt(ctx, info, nil, nil) if err != nil { return err } return is } // StatefulInterrupt creates a special error that signals the execution engine to interrupt // the current run at the component's specific address and save a checkpoint. // // This is the standard way for a single, non-composite component to signal a resumable interruption. // // - ctx: The context of the running component, used to retrieve the current execution address. // - info: User-facing information about the interrupt. This is not persisted but is exposed to the // calling application via the InterruptCtx to provide context (e.g., a reason for the pause). 
// - state: The internal state that the interrupting component needs to persist to be able to resume // its work later. This state is saved in the checkpoint and will be provided back to the component // upon resumption via GetInterruptState. func StatefulInterrupt(ctx context.Context, info any, state any) error { is, err := core.Interrupt(ctx, info, state, nil) if err != nil { return err } return is } // CompositeInterrupt creates a special error that signals a composite interruption. // It is designed for "composite" nodes (like ToolsNode) that manage multiple, independent, // interruptible sub-processes. It bundles multiple sub-interrupt errors into a single error // that the engine can deconstruct into a flat list of resumable points. // // This function is robust and can handle several types of errors from sub-processes: // // - A `Interrupt` or `StatefulInterrupt` error from a simple component. // // - A nested `CompositeInterrupt` error from another composite component. // // - An error containing `InterruptInfo` returned by a `Runnable` (e.g., a Graph within a lambda node). // // - An error returned by \'WrapInterruptAndRerunIfNeeded\' for the legacy old interrupt and rerun error, // and for the error returned by the deprecated old interrupt errors. // // Parameters: // // - ctx: The context of the running composite node. // // - info: User-facing information for the composite node itself. Can be nil. // This info will be attached to InterruptInfo.RerunNodeExtra. // Provided mainly for compatibility purpose as the composite node itself // is not an interrupt point with interrupt ID, // which means it lacks enough reason to give a user-facing info. // // - state: The state for the composite node itself. Can be nil. // This could be useful when the composite node needs to restore state, // such as its input (e.g. ToolsNode). // // - errs: a list of errors emitted by sub-processes. 
// // NOTE: if the error you passed in is the deprecated old interrupt and rerun err, or an error returned by // the deprecated old interrupt function, you must wrap it using WrapInterruptAndRerunIfNeeded first // before passing them into this function. func CompositeInterrupt(ctx context.Context, info any, state any, errs ...error) error { if len(errs) == 0 { return StatefulInterrupt(ctx, info, state) } var cErrs []*core.InterruptSignal for _, err := range errs { wrapped := &wrappedInterruptAndRerun{} if errors.As(err, &wrapped) { inner := wrapped.Unwrap() if errors.Is(inner, deprecatedInterruptAndRerun) { id := uuid.NewString() cErrs = append(cErrs, &core.InterruptSignal{ ID: id, Address: wrapped.ps, InterruptInfo: core.InterruptInfo{ Info: nil, IsRootCause: true, }, }) continue } ire := &core.InterruptSignal{} if errors.As(err, &ire) { id := uuid.NewString() cErrs = append(cErrs, &core.InterruptSignal{ ID: id, Address: wrapped.ps, InterruptInfo: core.InterruptInfo{ Info: ire.InterruptInfo.Info, IsRootCause: ire.InterruptInfo.IsRootCause, }, InterruptState: core.InterruptState{ State: ire.InterruptState.State, }, }) } continue } ire := &core.InterruptSignal{} if errors.As(err, &ire) { cErrs = append(cErrs, ire) continue } ie := &interruptError{} if errors.As(err, &ie) { is := core.FromInterruptContexts(ie.Info.InterruptContexts) cErrs = append(cErrs, is) continue } return fmt.Errorf("composite interrupt but one of the sub error is not interrupt and rerun error: %w", err) } is, err := core.Interrupt(ctx, info, state, cErrs) if err != nil { return err } return is } // IsInterruptRerunError reports whether the error represents an interrupt-and-rerun // and returns any attached info. 
func IsInterruptRerunError(err error) (any, bool) { info, _, ok := isInterruptRerunError(err) return info, ok } func isInterruptRerunError(err error) (info any, state any, ok bool) { if errors.Is(err, deprecatedInterruptAndRerun) { return nil, nil, true } ire := &core.InterruptSignal{} if errors.As(err, &ire) { return ire.Info, ire.State, true } return nil, nil, false } // InterruptInfo aggregates interrupt metadata for composite or nested runs. type InterruptInfo struct { State any BeforeNodes []string AfterNodes []string RerunNodes []string RerunNodesExtra map[string]any SubGraphs map[string]*InterruptInfo InterruptContexts []*InterruptCtx } func init() { schema.RegisterName[*InterruptInfo]("_eino_compose_interrupt_info") } // AddressSegmentType defines the type of a segment in an execution address. type AddressSegmentType = core.AddressSegmentType const ( // AddressSegmentNode represents a segment of an address that corresponds to a graph node. AddressSegmentNode AddressSegmentType = "node" // AddressSegmentTool represents a segment of an address that corresponds to a specific tool call within a ToolsNode. AddressSegmentTool AddressSegmentType = "tool" // AddressSegmentRunnable represents a segment of an address that corresponds to an instance of the Runnable interface. // Currently the possible Runnable types are: Graph, Workflow and Chain. // Note that for sub-graphs added through AddGraphNode to another graph is not a Runnable. // So a AddressSegmentRunnable indicates a standalone Root level Graph, // or a Root level Graph inside a node such as Lambda node. AddressSegmentRunnable AddressSegmentType = "runnable" ) // Address represents a full, hierarchical address to a point in the execution structure. type Address = core.Address // AddressSegment represents a single segment in the hierarchical address of an execution point. // A sequence of AddressSegments uniquely identifies a location within a potentially nested structure. 
type AddressSegment = core.AddressSegment // InterruptCtx provides a complete, user-facing context for a single, resumable interrupt point. type InterruptCtx = core.InterruptCtx // ExtractInterruptInfo extracts InterruptInfo from an error if present. func ExtractInterruptInfo(err error) (info *InterruptInfo, existed bool) { if err == nil { return nil, false } var iE *interruptError if errors.As(err, &iE) { return iE.Info, true } var sIE *subGraphInterruptError if errors.As(err, &sIE) { return sIE.Info, true } return nil, false } type interruptError struct { Info *InterruptInfo } func (e *interruptError) Error() string { return fmt.Sprintf("interrupt happened, info: %+v", e.Info) } func (e *interruptError) GetInterruptContexts() []*InterruptCtx { if e.Info == nil { return nil } return e.Info.InterruptContexts } func isSubGraphInterrupt(err error) *subGraphInterruptError { if err == nil { return nil } var iE *subGraphInterruptError if errors.As(err, &iE) { return iE } return nil } type subGraphInterruptError struct { Info *InterruptInfo CheckPoint *checkpoint signal *core.InterruptSignal } func (e *subGraphInterruptError) Error() string { return fmt.Sprintf("interrupt happened, info: %+v", e.Info) } func isInterruptError(err error) bool { if _, ok := ExtractInterruptInfo(err); ok { return true } if info := isSubGraphInterrupt(err); info != nil { return true } if _, ok := IsInterruptRerunError(err); ok { return true } return false } ================================================ FILE: compose/introspect.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package compose

import (
	"context"
	"reflect"

	"github.com/cloudwego/eino/components"
)

// GraphNodeInfo describes a node as the end user supplied it when adding it to a graph.
type GraphNodeInfo struct {
	// Component is the component kind of the node.
	Component components.Component
	// Instance is the node's underlying implementation.
	// NOTE(review): presumably the value passed to AddXxxNode — confirm.
	Instance any
	// GraphAddNodeOpts are the options given when the node was added.
	GraphAddNodeOpts []GraphAddNodeOpt
	// InputType and OutputType are recorded explicitly because, mainly for lambda nodes,
	// they cannot be inferred from the component type.
	InputType, OutputType reflect.Type // mainly for lambda, whose input and output types cannot be inferred by component type
	Name                  string
	InputKey, OutputKey   string
	// GraphInfo describes a nested graph.
	// NOTE(review): presumably non-nil only for sub-graph nodes — confirm.
	GraphInfo *GraphInfo
	Mappings  []*FieldMapping
}

// GraphInfo describes a graph as the end user supplied it when compiling it.
// It is passed to compile callbacks so users can inspect every node's info and instance,
// e.g. for observation.
type GraphInfo struct {
	CompileOptions []GraphCompileOption
	Nodes          map[string]GraphNodeInfo // node key -> node info
	Edges          map[string][]string      // edge start node key -> edge end node key, control edges
	// DataEdges presumably mirrors Edges (start node key -> end node keys) for data edges — confirm.
	DataEdges             map[string][]string
	Branches              map[string][]GraphBranch // branch start node key -> branch
	InputType, OutputType reflect.Type
	Name                  string

	NewGraphOptions []NewGraphOption
	GenStateFn      func(context.Context) any
}

// GraphCompileCallback is the callback which will be called when graph compilation finishes.
type GraphCompileCallback interface { OnFinish(ctx context.Context, info *GraphInfo) } ================================================ FILE: compose/pregel.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import "fmt" func pregelChannelBuilder(_ []string, _ []string, _ func() any, _ func() streamReader) channel { return &pregelChannel{Values: make(map[string]any)} } type pregelChannel struct { Values map[string]any mergeConfig FanInMergeConfig } func (ch *pregelChannel) setMergeConfig(cfg FanInMergeConfig) { ch.mergeConfig.StreamMergeWithSourceEOF = cfg.StreamMergeWithSourceEOF } func (ch *pregelChannel) load(c channel) error { dc, ok := c.(*pregelChannel) if !ok { return fmt.Errorf("load pregel channel fail, got %T, want *pregelChannel", c) } ch.Values = dc.Values return nil } func (ch *pregelChannel) convertValues(fn func(map[string]any) error) error { return fn(ch.Values) } func (ch *pregelChannel) reportValues(ins map[string]any) error { for k, v := range ins { ch.Values[k] = v } return nil } func (ch *pregelChannel) get(isStream bool, name string, edgeHandler *edgeHandlerManager) ( any, bool, error) { if len(ch.Values) == 0 { return nil, false, nil } defer func() { ch.Values = map[string]any{} }() values := make([]any, len(ch.Values)) names := make([]string, len(ch.Values)) i := 0 for k, v := range ch.Values { resolvedV, err := edgeHandler.handle(k, name, v, isStream) 
if err != nil { return nil, false, err } values[i] = resolvedV names[i] = k i++ } if len(values) == 1 { return values[0], true, nil } // merge mergeOpts := &mergeOptions{ streamMergeWithSourceEOF: ch.mergeConfig.StreamMergeWithSourceEOF, names: names, } v, err := mergeValues(values, mergeOpts) if err != nil { return nil, false, err } return v, true, nil } func (ch *pregelChannel) reportSkip(_ []string) bool { return false } func (ch *pregelChannel) reportDependencies(_ []string) { return } ================================================ FILE: compose/resume.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "github.com/cloudwego/eino/internal/core" ) // GetInterruptState provides a type-safe way to check for and retrieve the persisted state from a previous interruption. // It is the primary function a component should use to understand its past state. // // It returns three values: // - wasInterrupted (bool): True if the node was part of a previous interruption, regardless of whether state was provided. // - state (T): The typed state object, if it was provided and matches type `T`. // - hasState (bool): True if state was provided during the original interrupt and successfully cast to type `T`. 
func GetInterruptState[T any](ctx context.Context) (wasInterrupted bool, hasState bool, state T) { return core.GetInterruptState[T](ctx) } // GetResumeContext checks if the current component is the target of a resume operation // and retrieves any data provided by the user for that resumption. // // This function is typically called *after* a component has already determined it is in a // resumed state by calling GetInterruptState. // // It returns three values: // - isResumeFlow: A boolean that is true if the current component's address was explicitly targeted // by a call to Resume() or ResumeWithData(). // - hasData: A boolean that is true if data was provided for this component (i.e., not nil). // - data: The typed data provided by the user. // // ### How to Use This Function: A Decision Framework // // The correct usage pattern depends on the application's desired resume strategy. // // #### Strategy 1: Implicit "Resume All" // In some use cases, any resume operation implies that *all* interrupted points should proceed. // For example, if an application's UI only provides a single "Continue" button for a set of // interruptions. In this model, a component can often just use `GetInterruptState` to see if // `wasInterrupted` is true and then proceed with its logic, as it can assume it is an intended target. // It may still call `GetResumeContext` to check for optional data, but the `isResumeFlow` flag is less critical. // // #### Strategy 2: Explicit "Targeted Resume" (Most Common) // For applications with multiple, distinct interrupt points that must be resumed independently, it is // crucial to differentiate which point is being resumed. This is the primary use case for the `isResumeFlow` flag. // - If `isResumeFlow` is `true`: Your component is the explicit target. You should consume // the `data` (if any) and complete your work. // - If `isResumeFlow` is `false`: Another component is the target. 
You MUST re-interrupt // (e.g., by returning `StatefulInterrupt(...)`) to preserve your state and allow the // resume signal to propagate. // // ### Guidance for Composite Components // // Composite components (like `Graph` or other `Runnable`s that contain sub-processes) have a dual role: // 1. Check for Self-Targeting: A composite component can itself be the target of a resume // operation, for instance, to modify its internal state. It may call `GetResumeContext` // to check for data targeted at its own address. // 2. Act as a Conduit: After checking for itself, its primary role is to re-execute its children, // allowing the resume context to flow down to them. It must not consume a resume signal // intended for one of its descendants. func GetResumeContext[T any](ctx context.Context) (isResumeFlow bool, hasData bool, data T) { return core.GetResumeContext[T](ctx) } // GetCurrentAddress returns the hierarchical address of the currently executing component. // The address is a sequence of segments, each identifying a structural part of the execution // like an agent, a graph node, or a tool call. This can be useful for logging or debugging. func GetCurrentAddress(ctx context.Context) Address { return core.GetCurrentAddress(ctx) } // Resume prepares a context for an "Explicit Targeted Resume" operation by targeting one or more // components without providing data. It is a convenience wrapper around BatchResumeWithData. // // This is useful when the act of resuming is itself the signal, and no extra data is needed. // The components at the provided addresses (interrupt IDs) will receive `isResumeFlow = true` // when they call `GetResumeContext`. func Resume(ctx context.Context, interruptIDs ...string) context.Context { resumeData := make(map[string]any, len(interruptIDs)) for _, addr := range interruptIDs { resumeData[addr] = nil } return BatchResumeWithData(ctx, resumeData) } // ResumeWithData prepares a context to resume a single, specific component with data. 
// It is the primary function for the "Explicit Targeted Resume" strategy when data is required. // It is a convenience wrapper around BatchResumeWithData. // The `interruptID` parameter is the unique interrupt ID of the target component. func ResumeWithData(ctx context.Context, interruptID string, data any) context.Context { return BatchResumeWithData(ctx, map[string]any{interruptID: data}) } // BatchResumeWithData is the core function for preparing a resume context. It injects a map // of resume targets and their corresponding data into the context. // // The `resumeData` map should contain the interrupt IDs (which are the string form of addresses) of the // components to be resumed as keys. The value can be the resume data for that component, or `nil` // if no data is needed (equivalent to using `Resume`). // // This function is the foundation for the "Explicit Targeted Resume" strategy. Components whose interrupt IDs // are present as keys in the map will receive `isResumeFlow = true` when they call `GetResumeContext`. func BatchResumeWithData(ctx context.Context, resumeData map[string]any) context.Context { return core.BatchResumeWithData(ctx, resumeData) } func getNodePath(ctx context.Context) (*NodePath, bool) { currentAddress := GetCurrentAddress(ctx) if len(currentAddress) == 0 { return nil, false } nodePath := make([]string, 0, len(currentAddress)) for _, p := range currentAddress { if p.Type == AddressSegmentRunnable { nodePath = []string{} continue } nodePath = append(nodePath, p.ID) } return NewNodePath(nodePath...), len(nodePath) > 0 } // AppendAddressSegment creates a new execution context for a sub-component (e.g., a graph node or a tool call). // // It extends the current context's address with a new segment and populates the new context with the // appropriate interrupt state and resume data for that specific sub-address. // // - ctx: The parent context, typically the one passed into the component's Invoke/Stream method. 
// - segType: The type of the new address segment (e.g., "node", "tool"). // - segID: The unique ID for the new address segment. func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID string) context.Context { return core.AppendAddressSegment(ctx, segType, segID, "") } func appendToolAddressSegment(ctx context.Context, segID string, subID string) context.Context { return core.AppendAddressSegment(ctx, AddressSegmentTool, segID, subID) } ================================================ FILE: compose/resume_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "encoding/json" "sync" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) type myInterruptState struct { OriginalInput string } type myResumeData struct { Message string } type resumeTestState struct { OnStartCalledOnResume bool `json:"on_start_called_on_resume"` Counter int `json:"counter"` } func init() { schema.Register[resumeTestState]() } func TestInterruptStateAndResumeForRootGraph(t *testing.T) { // create a graph with a lambda node // this lambda node will interrupt with a typed state and an info for end-user // verify the info thrown by the lambda node // resume with a structured resume data // within the lambda node, getRunCtx and verify the state and resume data g := NewGraph[string, string]() lambda := InvokableLambda(func(ctx context.Context, input string) (string, error) { wasInterrupted, hasState, state := GetInterruptState[*myInterruptState](ctx) if !wasInterrupted { // First run: interrupt with state return "", StatefulInterrupt(ctx, map[string]any{"reason": "scheduled maintenance"}, &myInterruptState{OriginalInput: input}, ) } // This is a resumed run. 
assert.True(t, hasState) assert.Equal(t, "initial input", state.OriginalInput) isResume, hasData, data := GetResumeContext[*myResumeData](ctx) assert.True(t, isResume) assert.True(t, hasData) assert.Equal(t, "let's continue", data.Message) return "Resumed successfully with input: " + state.OriginalInput, nil }) _ = g.AddLambdaNode("lambda", lambda) _ = g.AddEdge(START, "lambda") _ = g.AddEdge("lambda", END) graph, err := g.Compile(context.Background(), WithCheckPointStore(newInMemoryStore()), WithGraphName("root")) assert.NoError(t, err) // First invocation, which should be interrupted checkPointID := "test-checkpoint-1" _, err = graph.Invoke(context.Background(), "initial input", WithCheckPointID(checkPointID)) // Verify the interrupt error and extracted info assert.Error(t, err) interruptInfo, isInterrupt := ExtractInterruptInfo(err) assert.True(t, isInterrupt) assert.NotNil(t, interruptInfo) interruptContexts := interruptInfo.InterruptContexts assert.Equal(t, 1, len(interruptContexts)) assert.Equal(t, "runnable:root;node:lambda", interruptContexts[0].Address.String()) assert.Equal(t, map[string]any{"reason": "scheduled maintenance"}, interruptContexts[0].Info) // Prepare resume data ctx := ResumeWithData(context.Background(), interruptContexts[0].ID, &myResumeData{Message: "let's continue"}) // Resume execution output, err := graph.Invoke(ctx, "", WithCheckPointID(checkPointID)) // Verify the final result assert.NoError(t, err) assert.Equal(t, "Resumed successfully with input: initial input", output) } func TestProcessStateInOnStartDuringResume(t *testing.T) { graphOnStartCallCount := 0 processStateErrorOnResume := error(nil) cb := callbacks.NewHandlerBuilder(). 
OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { if info.Name == "test-process-state-onstart" { graphOnStartCallCount++ err := ProcessState[*resumeTestState](ctx, func(ctx context.Context, s *resumeTestState) error { s.Counter++ return nil }) if graphOnStartCallCount > 1 { processStateErrorOnResume = err } } return ctx }). Build() g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) *resumeTestState { return &resumeTestState{} })) lambda := InvokableLambda(func(ctx context.Context, input string) (string, error) { wasInterrupted, _, _ := GetInterruptState[*myInterruptState](ctx) if !wasInterrupted { return "", StatefulInterrupt(ctx, map[string]any{"reason": "test interrupt"}, &myInterruptState{OriginalInput: input}, ) } var stateCounter int err := ProcessState[*resumeTestState](ctx, func(ctx context.Context, s *resumeTestState) error { stateCounter = s.Counter return nil }) assert.NoError(t, err) assert.Equal(t, 2, stateCounter, "Counter should be 2 (first run OnStart + resume OnStart)") return "success", nil }) _ = g.AddLambdaNode("lambda", lambda) _ = g.AddEdge(START, "lambda") _ = g.AddEdge("lambda", END) graph, err := g.Compile(context.Background(), WithCheckPointStore(newInMemoryStore()), WithGraphName("test-process-state-onstart"), ) assert.NoError(t, err) checkPointID := "test-checkpoint-process-state" _, err = graph.Invoke(context.Background(), "test input", WithCheckPointID(checkPointID), WithCallbacks(cb)) assert.Error(t, err, "First invocation should return an error") interruptInfo, isInterrupt := ExtractInterruptInfo(err) assert.True(t, isInterrupt, "Should be an interrupt error") assert.NotNil(t, interruptInfo) assert.Equal(t, 1, graphOnStartCallCount, "Graph OnStart should be called once on first run") ctx := ResumeWithData(context.Background(), interruptInfo.InterruptContexts[0].ID, &myResumeData{}) output, err := graph.Invoke(ctx, "", WithCheckPointID(checkPointID), 
WithCallbacks(cb)) assert.NoError(t, err) assert.Equal(t, "success", output) assert.Equal(t, 2, graphOnStartCallCount, "Graph OnStart should be called twice (first run + resume)") assert.NoError(t, processStateErrorOnResume, "ProcessState should work in OnStart during resume") } func TestInterruptStateAndResumeForSubGraph(t *testing.T) { // create a graph // create a another graph with a lambda node, as this graph as a sub-graph of the previous graph // this lambda node will interrupt with a typed state and an info for end-user // verify the info thrown by the lambda node // resume with a structured resume data // within the lambda node, getRunCtx and verify the state and resume data subGraph := NewGraph[string, string]() lambda := InvokableLambda(func(ctx context.Context, input string) (string, error) { wasInterrupted, hasState, state := GetInterruptState[*myInterruptState](ctx) if !wasInterrupted { // First run: interrupt with state return "", StatefulInterrupt(ctx, map[string]any{"reason": "sub-graph maintenance"}, &myInterruptState{OriginalInput: input}, ) } // Second (resumed) run assert.True(t, hasState) assert.Equal(t, "main input", state.OriginalInput) isResume, hasData, data := GetResumeContext[*myResumeData](ctx) assert.True(t, isResume) assert.True(t, hasData) assert.Equal(t, "let's continue sub-graph", data.Message) return "Sub-graph resumed successfully", nil }) _ = subGraph.AddLambdaNode("inner_lambda", lambda) _ = subGraph.AddEdge(START, "inner_lambda") _ = subGraph.AddEdge("inner_lambda", END) // Create the main graph mainGraph := NewGraph[string, string]() _ = mainGraph.AddGraphNode("sub_graph_node", subGraph) _ = mainGraph.AddEdge(START, "sub_graph_node") _ = mainGraph.AddEdge("sub_graph_node", END) compiledMainGraph, err := mainGraph.Compile(context.Background(), WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // First invocation, which should be interrupted checkPointID := "test-subgraph-checkpoint-1" _, err = 
compiledMainGraph.Invoke(context.Background(), "main input", WithCheckPointID(checkPointID)) // Verify the interrupt error and extracted info assert.Error(t, err) interruptInfo, isInterrupt := ExtractInterruptInfo(err) assert.True(t, isInterrupt) assert.NotNil(t, interruptInfo) interruptContexts := interruptInfo.InterruptContexts assert.Equal(t, 1, len(interruptContexts)) assert.Equal(t, "runnable:;node:sub_graph_node;node:inner_lambda", interruptContexts[0].Address.String()) assert.Equal(t, map[string]any{"reason": "sub-graph maintenance"}, interruptContexts[0].Info) // Prepare resume data ctx := ResumeWithData(context.Background(), interruptContexts[0].ID, &myResumeData{Message: "let's continue sub-graph"}) // Resume execution output, err := compiledMainGraph.Invoke(ctx, "", WithCheckPointID(checkPointID)) // Verify the final result assert.NoError(t, err) assert.Equal(t, "Sub-graph resumed successfully", output) } func TestInterruptStateAndResumeForToolInNestedSubGraph(t *testing.T) { // create a ROOT graph. // create a sub graph A, add A to ROOT graph using AddGraphNode. // create a sub-sub graph B, add B to A using AddGraphNode. // within sub-sub graph B, add a ChatModelNode, which is a Mock chat model that implements the ToolCallingChatModel // interface. // add a Mock InvokableTool to this mock chat model. // within sub-sub graph B, also add a ToolsNode that will execute this Mock InvokableTool. // this tool will interrupt with a typed state and an info for end-user // verify the info thrown by the tool. // resume with a structured resume data. // within the Tool, getRunCtx and verify the state and resume data ctrl := gomock.NewController(t) // 1. Define the interrupting tool mockTool := &mockInterruptingTool{tt: t} // 2. 
Define the sub-sub-graph (B) subSubGraphB := NewGraph[[]*schema.Message, []*schema.Message]() // Mock Chat Model that calls the tool mockChatModel := mockModel.NewMockToolCallingChatModel(ctrl) mockChatModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ {ID: "tool_call_123", Function: schema.FunctionCall{Name: "interrupt_tool", Arguments: `{"input": "test"}`}}, }, }, nil).AnyTimes() mockChatModel.EXPECT().WithTools(gomock.Any()).Return(mockChatModel, nil).AnyTimes() toolsNode, err := NewToolNode(context.Background(), &ToolsNodeConfig{Tools: []tool.BaseTool{mockTool}}) assert.NoError(t, err) _ = subSubGraphB.AddChatModelNode("model", mockChatModel) _ = subSubGraphB.AddToolsNode("tools", toolsNode) _ = subSubGraphB.AddEdge(START, "model") _ = subSubGraphB.AddEdge("model", "tools") _ = subSubGraphB.AddEdge("tools", END) // 3. Define sub-graph (A) subGraphA := NewGraph[[]*schema.Message, []*schema.Message]() _ = subGraphA.AddGraphNode("sub_graph_b", subSubGraphB) _ = subGraphA.AddEdge(START, "sub_graph_b") _ = subGraphA.AddEdge("sub_graph_b", END) // 4. Define root graph rootGraph := NewGraph[[]*schema.Message, []*schema.Message]() _ = rootGraph.AddGraphNode("sub_graph_a", subGraphA) _ = rootGraph.AddEdge(START, "sub_graph_a") _ = rootGraph.AddEdge("sub_graph_a", END) // 5. Compile and run compiledRootGraph, err := rootGraph.Compile(context.Background(), WithCheckPointStore(newInMemoryStore()), WithGraphName("root")) assert.NoError(t, err) // First invocation - should interrupt checkPointID := "test-nested-tool-interrupt" initialInput := []*schema.Message{schema.UserMessage("hello")} _, err = compiledRootGraph.Invoke(context.Background(), initialInput, WithCheckPointID(checkPointID)) // 6. 
Verify the interrupt assert.Error(t, err) interruptInfo, isInterrupt := ExtractInterruptInfo(err) assert.True(t, isInterrupt) assert.NotNil(t, interruptInfo) interruptContexts := interruptInfo.InterruptContexts assert.Len(t, interruptContexts, 1) // Only the root cause is returned // Verify the root cause context rootCause := interruptContexts[0] expectedPath := "runnable:root;node:sub_graph_a;node:sub_graph_b;node:tools;tool:interrupt_tool:tool_call_123" assert.Equal(t, expectedPath, rootCause.Address.String()) assert.True(t, rootCause.IsRootCause) assert.Equal(t, map[string]any{"reason": "tool maintenance"}, rootCause.Info) // Verify the parent via the Parent field assert.NotNil(t, rootCause.Parent) assert.Equal(t, "runnable:root;node:sub_graph_a;node:sub_graph_b;node:tools", rootCause.Parent.Address.String()) assert.False(t, rootCause.Parent.IsRootCause) // 7. Resume execution ctx := ResumeWithData(context.Background(), rootCause.ID, &myResumeData{Message: "let's continue tool"}) output, err := compiledRootGraph.Invoke(ctx, initialInput, WithCheckPointID(checkPointID)) // 8. Verify final result assert.NoError(t, err) assert.NotNil(t, output) assert.Len(t, output, 1) assert.Equal(t, "Tool resumed successfully", output[0].Content) } const PathSegmentTypeProcess AddressSegmentType = "process" // processState is the state for a single sub-process in the batch test. type processState struct { Step int } // batchState is the composite state for the whole batch lambda. 
type batchState struct { ProcessStates map[string]*processState Results map[string]string } type processResumeData struct { Instruction string } func init() { schema.RegisterName[*myInterruptState]("my_interrupt_state") schema.RegisterName[*batchState]("batch_state") schema.RegisterName[*processState]("process_state") } func TestMultipleInterruptsAndResumes(t *testing.T) { // define a new lambda node that act as a 'batch' node // it kick starts 3 parallel processes, each will interrupt on first run, while preserving their own state. // each of the process should have their own user-facing interrupt info. // define a new AddressSegmentType for these sub processes. // the lambda should use StatefulInterrupt to interrupt and preserve the state, // which is a specific struct type that implements the CompositeInterruptState interface. // there should also be a specific struct that that implements the CompositeInterruptInfo interface, // which helps the end-user to fetch the nested interrupt info. // put this lambda node within a graph and invoke the graph. // simulate the user getting the flat list of 3 interrupt points using GetInterruptContexts // the user then decides to resume two of the three interrupt points // the first resume has resume data, while the second resume does not.(ResumeWithData vs. Resume) // verify the resume data and state for the resumed interrupt points. processIDs := []string{"p0", "p1", "p2"} // This is the logic for a single "process" runProcess := func(ctx context.Context, id string) (string, error) { // Check if this specific process was interrupted before wasInterrupted, hasState, pState := GetInterruptState[*processState](ctx) if !wasInterrupted { // First run for this process, interrupt it. 
return "", StatefulInterrupt(ctx, map[string]any{"reason": "process " + id + " needs input"}, &processState{Step: 1}, ) } assert.True(t, hasState) assert.Equal(t, 1, pState.Step) // Check if we are being resumed isResume, hasData, pData := GetResumeContext[*processResumeData](ctx) if !isResume { // Not being resumed, so interrupt again. return "", StatefulInterrupt(ctx, map[string]any{"reason": "process " + id + " still needs input"}, pState, ) } // We are being resumed. if hasData { // Resumed with data return "process " + id + " done with instruction: " + pData.Instruction, nil } // Resumed without data return "process " + id + " done", nil } // This is the main "batch" lambda that orchestrates the processes batchLambda := InvokableLambda(func(ctx context.Context, _ string) (map[string]string, error) { // Restore the state of the batch node itself _, _, persistedBatchState := GetInterruptState[*batchState](ctx) if persistedBatchState == nil { persistedBatchState = &batchState{ Results: make(map[string]string), } } var errs []error for _, id := range processIDs { // If this process already completed in a previous run, skip it. if _, done := persistedBatchState.Results[id]; done { continue } // Create a sub-context for each process subCtx := AppendAddressSegment(ctx, PathSegmentTypeProcess, id) res, err := runProcess(subCtx, id) if err != nil { _, ok := IsInterruptRerunError(err) assert.True(t, ok) errs = append(errs, err) } else { // Process completed, save its result to the state for the next run. persistedBatchState.Results[id] = res } } if len(errs) > 0 { return nil, CompositeInterrupt(ctx, nil, persistedBatchState, errs...) } return persistedBatchState.Results, nil }) g := NewGraph[string, map[string]string]() _ = g.AddLambdaNode("batch", batchLambda) _ = g.AddEdge(START, "batch") _ = g.AddEdge("batch", END) graph, err := g.Compile(context.Background(), WithCheckPointStore(newInMemoryStore()), WithGraphName("root")) assert.NoError(t, err) // --- 1. 
First invocation, all 3 processes should interrupt --- checkPointID := "multi-interrupt-test" _, err = graph.Invoke(context.Background(), "", WithCheckPointID(checkPointID)) assert.Error(t, err) interruptInfo, isInterrupt := ExtractInterruptInfo(err) assert.True(t, isInterrupt) interruptContexts := interruptInfo.InterruptContexts assert.Len(t, interruptContexts, 3) // Only the 3 root causes found := make(map[string]bool) addrToID := make(map[string]string) var parentCtx *InterruptCtx for _, iCtx := range interruptContexts { addrStr := iCtx.Address.String() found[addrStr] = true addrToID[addrStr] = iCtx.ID assert.True(t, iCtx.IsRootCause) assert.Equal(t, map[string]any{"reason": "process " + iCtx.Address[2].ID + " needs input"}, iCtx.Info) // Check that all share the same parent assert.NotNil(t, iCtx.Parent) if parentCtx == nil { parentCtx = iCtx.Parent assert.Equal(t, "runnable:root;node:batch", parentCtx.Address.String()) assert.False(t, parentCtx.IsRootCause) } else { assert.Same(t, parentCtx, iCtx.Parent) } } assert.True(t, found["runnable:root;node:batch;process:p0"]) assert.True(t, found["runnable:root;node:batch;process:p1"]) assert.True(t, found["runnable:root;node:batch;process:p2"]) // --- 2. Second invocation, resume 2 of 3 processes --- // Resume p0 with data, and p2 without data. p1 remains interrupted. 
resumeCtx := ResumeWithData(context.Background(), addrToID["runnable:root;node:batch;process:p0"], &processResumeData{Instruction: "do it"}) resumeCtx = Resume(resumeCtx, addrToID["runnable:root;node:batch;process:p2"]) _, err = graph.Invoke(resumeCtx, "", WithCheckPointID(checkPointID)) // Expect an interrupt again, but only for p1 assert.Error(t, err) interruptInfo2, isInterrupt2 := ExtractInterruptInfo(err) assert.True(t, isInterrupt2) interruptContexts2 := interruptInfo2.InterruptContexts assert.Len(t, interruptContexts2, 1) // Only p1 is left rootCause2 := interruptContexts2[0] assert.Equal(t, "runnable:root;node:batch;process:p1", rootCause2.Address.String()) assert.NotNil(t, rootCause2.Parent) assert.Equal(t, "runnable:root;node:batch", rootCause2.Parent.Address.String()) // --- 3. Third invocation, resume the last process --- finalResumeCtx := Resume(context.Background(), rootCause2.ID) finalOutput, err := graph.Invoke(finalResumeCtx, "", WithCheckPointID(checkPointID)) assert.NoError(t, err) assert.Equal(t, "process p0 done with instruction: do it", finalOutput["p0"]) assert.Equal(t, "process p1 done", finalOutput["p1"]) assert.Equal(t, "process p2 done", finalOutput["p2"]) } // toolsNodeResumeTargetCallback captures isResumeTarget for ToolsNode during OnStart type toolsNodeResumeTargetCallback struct { mu sync.Mutex isResumeTargetLog []bool } func (c *toolsNodeResumeTargetCallback) OnStart(ctx context.Context, info *callbacks.RunInfo, _ callbacks.CallbackInput) context.Context { if info.Component == ComponentOfToolsNode { isResumeTarget, _, _ := GetResumeContext[any](ctx) c.mu.Lock() c.isResumeTargetLog = append(c.isResumeTargetLog, isResumeTarget) c.mu.Unlock() } return ctx } func (c *toolsNodeResumeTargetCallback) OnEnd(ctx context.Context, _ *callbacks.RunInfo, _ callbacks.CallbackOutput) context.Context { return ctx } func (c *toolsNodeResumeTargetCallback) OnError(ctx context.Context, _ *callbacks.RunInfo, _ error) context.Context { return ctx } func 
(c *toolsNodeResumeTargetCallback) OnStartWithStreamInput(ctx context.Context, _ *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { input.Close() return ctx } func (c *toolsNodeResumeTargetCallback) OnEndWithStreamOutput(ctx context.Context, _ *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { output.Close() return ctx } // mockReentryTool is a helper for the reentry test type mockReentryTool struct { t *testing.T mu sync.Mutex isResumeTargetByRunID map[string]bool } func (t *mockReentryTool) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "reentry_tool", Desc: "A tool that can be re-entered in a resumed graph.", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{"input": {Type: schema.String}}), }, nil } func (t *mockReentryTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { wasInterrupted, hasState, _ := tool.GetInterruptState[any](ctx) isResume, hasData, data := tool.GetResumeContext[*myResumeData](ctx) callID := GetToolCallID(ctx) t.mu.Lock() if t.isResumeTargetByRunID != nil { t.isResumeTargetByRunID[callID] = isResume } t.mu.Unlock() // Special handling for the re-entrant call to make assertions explicit. if callID == "call_3" { if !isResume { // This is the first run of the re-entrant call. Its context must be clean. // This is the core assertion for this test. assert.False(t.t, wasInterrupted, "re-entrant call 'call_3' should not have been interrupted on its first run") assert.False(t.t, hasState, "re-entrant call 'call_3' should not have state on its first run") // Now, interrupt it as part of the test flow. return "", tool.StatefulInterrupt(ctx, nil, "some state for "+callID) } // This is the resumed run of the re-entrant call. 
assert.True(t.t, wasInterrupted, "resumed call 'call_3' must have been interrupted") assert.True(t.t, hasData, "resumed call 'call_3' should have data") return "Resumed " + data.Message, nil } // Standard logic for the initial calls (call_1, call_2) if !wasInterrupted { // First run for call_1 and call_2, should interrupt. return "", tool.StatefulInterrupt(ctx, nil, "some state for "+callID) } // From here, wasInterrupted is true for call_1 and call_2. if isResume { // The user is explicitly resuming this call. assert.True(t.t, hasData, "call %s should have resume data", callID) return "Resumed " + data.Message, nil } // The tool was interrupted before, but is not being resumed now. Re-interrupt. return "", tool.StatefulInterrupt(ctx, nil, "some state for "+callID) } func TestReentryForResumedTools(t *testing.T) { // create a 'ReAct' style graph with a ChatModel node and a ToolsNode. // within the ToolsNode there is an interruptible tool that will emit interrupt on first run. // During the first invocation of the graph, there should be two tool calls (of the same tool) that interrupt. // The user chooses to resume one of the interrupted tool call in second invocation, // and this time, the resumed tool call should be successful, while the other should interrupt immediately again. // The user then chooses to resume the other interrupted tool call in third invocation, // and this time, the ChatModel decides to call the tool again, // and this time the tool's runCtx should think it was not interrupted nor resumed. ctrl := gomock.NewController(t) // 1. Define the interrupting tool and callback reentryTool := &mockReentryTool{t: t, isResumeTargetByRunID: make(map[string]bool)} toolsNodeCB := &toolsNodeResumeTargetCallback{} // 2. 
Define the graph g := NewGraph[[]*schema.Message, *schema.Message]() // Mock Chat Model that drives the ReAct loop mockChatModel := mockModel.NewMockToolCallingChatModel(ctrl) toolsNode, err := NewToolNode(context.Background(), &ToolsNodeConfig{Tools: []tool.BaseTool{reentryTool}}) assert.NoError(t, err) // Expectation for the 1st invocation: model returns two tool calls mockChatModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ {ID: "call_1", Function: schema.FunctionCall{Name: "reentry_tool", Arguments: `{"input": "a"}`}}, {ID: "call_2", Function: schema.FunctionCall{Name: "reentry_tool", Arguments: `{"input": "b"}`}}, }, }, nil).Times(1) // Expectation for the 2nd invocation (after resuming call_1): model does nothing, graph continues // Expectation for the 3rd invocation (after resuming call_2): model calls the tool again mockChatModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { return &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ {ID: "call_3", Function: schema.FunctionCall{Name: "reentry_tool", Arguments: `{"input": "c"}`}}, }, }, nil }).Times(1) // Expectation for the final invocation: model returns final answer mockChatModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&schema.Message{ Role: schema.Assistant, Content: "all done", }, nil).Times(1) _ = g.AddChatModelNode("model", mockChatModel) _ = g.AddToolsNode("tools", toolsNode) _ = g.AddEdge(START, "model") // Add the crucial branch to decide whether to call tools or end. 
modelBranch := func(ctx context.Context, msg *schema.Message) (string, error) { if len(msg.ToolCalls) > 0 { return "tools", nil } return END, nil } err = g.AddBranch("model", NewGraphBranch(modelBranch, map[string]bool{"tools": true, END: true})) assert.NoError(t, err) _ = g.AddEdge("tools", "model") // Loop back for ReAct style // 3. Compile and run graph, err := g.Compile(context.Background(), WithCheckPointStore(newInMemoryStore()), WithGraphName("root")) assert.NoError(t, err) checkPointID := "reentry-test" // --- 1. First invocation: call_1 and call_2 should interrupt --- _, err = graph.Invoke(context.Background(), []*schema.Message{schema.UserMessage("start")}, WithCheckPointID(checkPointID), WithCallbacks(toolsNodeCB)) assert.Error(t, err) interruptInfo1, _ := ExtractInterruptInfo(err) interrupts1 := interruptInfo1.InterruptContexts assert.Len(t, interrupts1, 2) // Only the two tool calls found1 := make(map[string]bool) addrToID1 := make(map[string]string) for _, iCtx := range interrupts1 { addrStr := iCtx.Address.String() found1[addrStr] = true addrToID1[addrStr] = iCtx.ID assert.True(t, iCtx.IsRootCause) assert.NotNil(t, iCtx.Parent) assert.Equal(t, "runnable:root;node:tools", iCtx.Parent.Address.String()) } assert.True(t, found1["runnable:root;node:tools;tool:reentry_tool:call_1"]) assert.True(t, found1["runnable:root;node:tools;tool:reentry_tool:call_2"]) // First invocation: neither call_1 nor call_2 should be resume targets assert.False(t, reentryTool.isResumeTargetByRunID["call_1"], "first run: call_1 should not be resume target") assert.False(t, reentryTool.isResumeTargetByRunID["call_2"], "first run: call_2 should not be resume target") // First invocation: ToolsNode should NOT be a resume target assert.Len(t, toolsNodeCB.isResumeTargetLog, 1, "ToolsNode OnStart should be called once in first invocation") assert.False(t, toolsNodeCB.isResumeTargetLog[0], "first run: ToolsNode should NOT be resume target") // Clear for next invocation 
reentryTool.isResumeTargetByRunID = make(map[string]bool) toolsNodeCB.isResumeTargetLog = nil // --- 2. Second invocation: resume call_1, expect call_2 to interrupt again --- resumeCtx2 := ResumeWithData(context.Background(), addrToID1["runnable:root;node:tools;tool:reentry_tool:call_1"], &myResumeData{Message: "resume call 1"}) _, err = graph.Invoke(resumeCtx2, []*schema.Message{schema.UserMessage("start")}, WithCheckPointID(checkPointID), WithCallbacks(toolsNodeCB)) assert.Error(t, err) interruptInfo2, _ := ExtractInterruptInfo(err) interrupts2 := interruptInfo2.InterruptContexts assert.Len(t, interrupts2, 1) // Only call_2 rootCause2 := interrupts2[0] assert.Equal(t, "runnable:root;node:tools;tool:reentry_tool:call_2", rootCause2.Address.String()) assert.NotNil(t, rootCause2.Parent) assert.Equal(t, "runnable:root;node:tools", rootCause2.Parent.Address.String()) // Second invocation: call_1 is resumed, call_2 is NOT resumed (re-interrupts) assert.True(t, reentryTool.isResumeTargetByRunID["call_1"], "second run: call_1 should be resume target") assert.False(t, reentryTool.isResumeTargetByRunID["call_2"], "second run: call_2 should NOT be resume target (it re-interrupts)") // Second invocation: ToolsNode SHOULD be a resume target (because call_1 child is being resumed) assert.Len(t, toolsNodeCB.isResumeTargetLog, 1, "ToolsNode OnStart should be called once in second invocation") assert.True(t, toolsNodeCB.isResumeTargetLog[0], "second run: ToolsNode SHOULD be resume target (child call_1 is being resumed)") // Clear for next invocation reentryTool.isResumeTargetByRunID = make(map[string]bool) toolsNodeCB.isResumeTargetLog = nil // --- 3. 
Third invocation: resume call_2, model makes a new call (call_3) which should interrupt --- resumeCtx3 := ResumeWithData(context.Background(), rootCause2.ID, &myResumeData{Message: "resume call 2"}) _, err = graph.Invoke(resumeCtx3, []*schema.Message{schema.UserMessage("start")}, WithCheckPointID(checkPointID), WithCallbacks(toolsNodeCB)) assert.Error(t, err) interruptInfo3, _ := ExtractInterruptInfo(err) interrupts3 := interruptInfo3.InterruptContexts assert.Len(t, interrupts3, 1) // Only call_3 rootCause3 := interrupts3[0] assert.Equal(t, "runnable:root;node:tools;tool:reentry_tool:call_3", rootCause3.Address.String()) // Note: this is the new call_3 assert.NotNil(t, rootCause3.Parent) assert.Equal(t, "runnable:root;node:tools", rootCause3.Parent.Address.String()) // Third invocation: call_2 is resumed, call_3 is new (not resumed) assert.True(t, reentryTool.isResumeTargetByRunID["call_2"], "third run: call_2 should be resume target") assert.False(t, reentryTool.isResumeTargetByRunID["call_3"], "third run: call_3 should NOT be resume target (it's new)") // Third invocation: ToolsNode is called twice (once for call_2 resume, once for call_3 new) // First call: ToolsNode SHOULD be resume target (call_2 is being resumed) // Second call: ToolsNode should NOT be resume target (call_3 is new, no children to resume) assert.Len(t, toolsNodeCB.isResumeTargetLog, 2, "ToolsNode OnStart should be called twice in third invocation") assert.True(t, toolsNodeCB.isResumeTargetLog[0], "third run first ToolsNode call: SHOULD be resume target (child call_2 is being resumed)") assert.False(t, toolsNodeCB.isResumeTargetLog[1], "third run second ToolsNode call: should NOT be resume target (call_3 is new)") // Clear for next invocation reentryTool.isResumeTargetByRunID = make(map[string]bool) toolsNodeCB.isResumeTargetLog = nil // --- 4. 
Final invocation: resume call_3, expect final answer --- resumeCtx4 := ResumeWithData(context.Background(), rootCause3.ID, &myResumeData{Message: "resume call 3"}) output, err := graph.Invoke(resumeCtx4, []*schema.Message{schema.UserMessage("start")}, WithCheckPointID(checkPointID), WithCallbacks(toolsNodeCB)) assert.NoError(t, err) assert.Equal(t, "all done", output.Content) // Fourth invocation: call_3 is resumed assert.True(t, reentryTool.isResumeTargetByRunID["call_3"], "fourth run: call_3 should be resume target") // Fourth invocation: ToolsNode SHOULD be resume target (call_3 is being resumed) assert.Len(t, toolsNodeCB.isResumeTargetLog, 1, "ToolsNode OnStart should be called once in fourth invocation") assert.True(t, toolsNodeCB.isResumeTargetLog[0], "fourth run: ToolsNode SHOULD be resume target (child call_3 is being resumed)") } // mockInterruptingTool is a helper for the nested tool interrupt test type mockInterruptingTool struct { tt *testing.T } func (t *mockInterruptingTool) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "interrupt_tool", Desc: "A tool that interrupts execution.", ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "input": {Type: schema.String, Desc: "Some input", Required: true}, }), }, nil } func (t *mockInterruptingTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { var args map[string]string _ = json.Unmarshal([]byte(argumentsInJSON), &args) wasInterrupted, hasState, state := tool.GetInterruptState[*myInterruptState](ctx) if !wasInterrupted { // First run: interrupt return "", tool.StatefulInterrupt(ctx, map[string]any{"reason": "tool maintenance"}, &myInterruptState{OriginalInput: args["input"]}, ) } // Second (resumed) run assert.True(t.tt, hasState) assert.Equal(t.tt, "test", state.OriginalInput) isResume, hasData, data := tool.GetResumeContext[*myResumeData](ctx) assert.True(t.tt, isResume) assert.True(t.tt, hasData) 
assert.Equal(t.tt, "let's continue tool", data.Message) return "Tool resumed successfully", nil } func TestGraphInterruptWithinLambda(t *testing.T) { // this test case aims to verify behaviors when a standalone graph is within a lambda, // which in turn is within the root graph. // the expected behavior is: // - internal graph will naturally append to the Address // - internal graph interrupts, where the Address includes steps for both the root graph and the internal graph // - lambda extracts InterruptInfo, then GetInterruptContexts // - lambda then acts as a composite node, uses CompositeInterrupt to pass up the // internal interrupt points // - the root graph interrupts // - end-user extracts the interrupt ID and related info // - end-user uses ResumeWithData to resume the ID // - lambda node resumes, invokes the inner graph as usual // - the internal graph resumes the interrupted node // To implement this test, within the internal graph you can define another lambda node that can interrupt resume. // 1. Define the innermost lambda that actually interrupts interruptingLambda := InvokableLambda(func(ctx context.Context, input string) (string, error) { wasInterrupted, hasState, state := GetInterruptState[*myInterruptState](ctx) if !wasInterrupted { return "", StatefulInterrupt(ctx, "inner interrupt", &myInterruptState{OriginalInput: input}) } assert.True(t, hasState) assert.Equal(t, "top level input", state.OriginalInput) isResume, hasData, data := GetResumeContext[*myResumeData](ctx) assert.True(t, isResume) assert.True(t, hasData) assert.Equal(t, "resume inner", data.Message) return "inner lambda resumed successfully", nil }) // 2. Define the internal graph that contains the interrupting lambda innerGraph := NewGraph[string, string]() _ = innerGraph.AddLambdaNode("inner_lambda", interruptingLambda) _ = innerGraph.AddEdge(START, "inner_lambda") _ = innerGraph.AddEdge("inner_lambda", END) // Give the inner graph a name so it can create its "runnable" addr step. 
compiledInnerGraph, err := innerGraph.Compile(context.Background(), WithGraphName("inner"), WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // 3. Define the outer lambda that acts as a composite node compositeLambda := InvokableLambda(func(ctx context.Context, input string) (string, error) { // The lambda invokes the inner graph. If the inner graph interrupts, this lambda // must act as a proper composite node and wrap the error. output, err := compiledInnerGraph.Invoke(ctx, input, WithCheckPointID("inner-cp")) if err != nil { _, isInterrupt := ExtractInterruptInfo(err) if !isInterrupt { return "", err // Not an interrupt, just fail } // The composite interrupt itself can be stateless, as it's just a wrapper. // It signals to the framework to look inside the subErrs and correctly // prepend the current addr to the paths of the inner interrupts. return "", CompositeInterrupt(ctx, "composite interrupt from lambda", nil, err) } return output, nil }) // 4. Define the root graph rootGraph := NewGraph[string, string]() _ = rootGraph.AddLambdaNode("composite_lambda", compositeLambda) _ = rootGraph.AddEdge(START, "composite_lambda") _ = rootGraph.AddEdge("composite_lambda", END) // Give the root graph a name for its "runnable" addr step. compiledRootGraph, err := rootGraph.Compile(context.Background(), WithGraphName("root"), WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // 5. First invocation - should interrupt checkPointID := "graph-in-lambda-test" _, err = compiledRootGraph.Invoke(context.Background(), "top level input", WithCheckPointID(checkPointID)) // 6. Verify the interrupt assert.Error(t, err) interruptInfo, isInterrupt := ExtractInterruptInfo(err) assert.True(t, isInterrupt) interruptContexts := interruptInfo.InterruptContexts assert.Len(t, interruptContexts, 1) // Only the root cause is returned // The addr is now fully qualified, including the runnable steps from both graphs. 
rootCause := interruptContexts[0] expectedPath := "runnable:root;node:composite_lambda;runnable:inner;node:inner_lambda" assert.Equal(t, expectedPath, rootCause.Address.String()) assert.Equal(t, "inner interrupt", rootCause.Info) assert.True(t, rootCause.IsRootCause) // Check parent hierarchy assert.NotNil(t, rootCause.Parent) assert.Equal(t, "runnable:root;node:composite_lambda;runnable:inner", rootCause.Parent.Address.String()) assert.Nil(t, rootCause.Parent.Info) // The inner runnable doesn't have its own info assert.False(t, rootCause.Parent.IsRootCause) // Check grandparent assert.NotNil(t, rootCause.Parent.Parent) assert.Equal(t, "runnable:root;node:composite_lambda", rootCause.Parent.Parent.Address.String()) assert.Equal(t, "composite interrupt from lambda", rootCause.Parent.Parent.Info) assert.False(t, rootCause.Parent.Parent.IsRootCause) // 7. Resume execution using the complete, fully-qualified ID resumeCtx := ResumeWithData(context.Background(), rootCause.ID, &myResumeData{Message: "resume inner"}) finalOutput, err := compiledRootGraph.Invoke(resumeCtx, "top level input", WithCheckPointID(checkPointID)) // 8. Verify final result assert.NoError(t, err) assert.Equal(t, "inner lambda resumed successfully", finalOutput) } func TestLegacyInterrupt(t *testing.T) { // this test case aims to test the behavior of the deprecated InterruptAndRerun, // NewInterruptAndRerunErr within CompositeInterrupt. // Define two sub-processes(functions), one interrupts with InterruptAndRerun, // the other interrupts with NewInterruptAndRerunErr. // create a lambda as a composite node, within the lambda invokes the two sub-processes. // create the graph, add lambda node and invoke it. // after verifying the interrupt points, just invokes again without explicit resume. // verify the same interrupt IDs again. // then finally Resume() the graph. // 1. 
Define the sub-processes that use legacy and modern interrupts subProcess1 := func(ctx context.Context) (string, error) { isResume, _, data := GetResumeContext[string](ctx) if isResume { return data, nil } return "", deprecatedInterruptAndRerun } subProcess2 := func(ctx context.Context) (string, error) { isResume, _, data := GetResumeContext[string](ctx) if isResume { return data, nil } return "", deprecatedInterruptAndRerunErr("legacy info") } subProcess3 := func(ctx context.Context) (string, error) { isResume, _, data := GetResumeContext[string](ctx) if isResume { return data, nil } // Use the modern, addr-aware interrupt function return "", Interrupt(ctx, "modern info") } // 2. Define the composite lambda compositeLambda := InvokableLambda(func(ctx context.Context, input string) (string, error) { // If the lambda itself is being resumed, it means the whole process is done. isResume, _, data := GetResumeContext[string](ctx) // Run sub-processes and collect their errors var ( errs []error outStr string ) const PathStepCustom AddressSegmentType = "custom" subCtx1 := AppendAddressSegment(ctx, PathStepCustom, "1") out1, err1 := subProcess1(subCtx1) if err1 != nil { // Wrap the legacy error to give it a addr wrappedErr := WrapInterruptAndRerunIfNeeded(ctx, AddressSegment{Type: PathStepCustom, ID: "1"}, err1) errs = append(errs, wrappedErr) } else { outStr += out1 } subCtx2 := AppendAddressSegment(ctx, PathStepCustom, "2") out2, err2 := subProcess2(subCtx2) if err2 != nil { // Wrap the legacy error to give it a addr wrappedErr := WrapInterruptAndRerunIfNeeded(ctx, AddressSegment{Type: PathStepCustom, ID: "2"}, err2) errs = append(errs, wrappedErr) } else { outStr += out2 } subCtx3 := AppendAddressSegment(ctx, PathStepCustom, "3") out3, err3 := subProcess3(subCtx3) if err3 != nil { // The error from Interrupt() is already addr-aware. WrapInterruptAndRerunIfNeeded // should handle this gracefully and return the error as-is. 
wrappedErr := WrapInterruptAndRerunIfNeeded(ctx, AddressSegment{Type: PathStepCustom, ID: "3"}, err3) errs = append(errs, wrappedErr) } else { outStr += out3 } if len(errs) > 0 { // Return a composite interrupt containing the wrapped legacy errors return "", CompositeInterrupt(ctx, "legacy composite", nil, errs...) } if isResume { outStr = outStr + " " + data } return outStr, nil }) // 3. Create and compile the graph rootGraph := NewGraph[string, string]() _ = rootGraph.AddLambdaNode("legacy_composite", compositeLambda) _ = rootGraph.AddEdge(START, "legacy_composite") _ = rootGraph.AddEdge("legacy_composite", END) compiledGraph, err := rootGraph.Compile(context.Background(), WithGraphName("root"), WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // 4. First invocation - should interrupt checkPointID := "legacy-interrupt-test" _, err = compiledGraph.Invoke(context.Background(), "input", WithCheckPointID(checkPointID)) // 5. Verify the three interrupt points assert.Error(t, err) info, isInterrupt := ExtractInterruptInfo(err) assert.True(t, isInterrupt) assert.Len(t, info.InterruptContexts, 3) // Only the 3 root causes found := make(map[string]any) addrToID := make(map[string]string) var parentCtx *InterruptCtx for _, iCtx := range info.InterruptContexts { addrStr := iCtx.Address.String() found[addrStr] = iCtx.Info addrToID[addrStr] = iCtx.ID assert.True(t, iCtx.IsRootCause) // Check parent assert.NotNil(t, iCtx.Parent) if parentCtx == nil { parentCtx = iCtx.Parent assert.Equal(t, "runnable:root;node:legacy_composite", parentCtx.Address.String()) assert.Equal(t, "legacy composite", parentCtx.Info) assert.False(t, parentCtx.IsRootCause) } else { assert.Same(t, parentCtx, iCtx.Parent) } } expectedID1 := "runnable:root;node:legacy_composite;custom:1" expectedID2 := "runnable:root;node:legacy_composite;custom:2" expectedID3 := "runnable:root;node:legacy_composite;custom:3" assert.Contains(t, found, expectedID1) assert.Nil(t, found[expectedID1]) // From 
InterruptAndRerun assert.Contains(t, found, expectedID2) assert.Equal(t, "legacy info", found[expectedID2]) // From NewInterruptAndRerunErr assert.Contains(t, found, expectedID3) assert.Equal(t, "modern info", found[expectedID3]) // From Interrupt // 6. Second invocation (re-run without resume) - should yield the same interrupts _, err = compiledGraph.Invoke(context.Background(), "input", WithCheckPointID(checkPointID)) assert.Error(t, err) info2, isInterrupt2 := ExtractInterruptInfo(err) assert.True(t, isInterrupt2) assert.Len(t, info2.InterruptContexts, 3, "Should have the same number of interrupts on re-run") // 7. Third invocation - Resume all three interrupt points with specific data resumeData := map[string]any{ addrToID[expectedID1]: "output1", addrToID[expectedID2]: "output2", addrToID[expectedID3]: "output3", } resumeCtx := BatchResumeWithData(context.Background(), resumeData) // TODO: The legacy interrupt wrapping does not currently work correctly with BatchResumeWithData. // The graph re-interrupts instead of completing. This should be fixed in the core framework. 
	_, err = compiledGraph.Invoke(resumeCtx, "input", WithCheckPointID(checkPointID))
	assert.Error(t, err)
}

// wrapperToolForTest is an invokable tool that runs a compiled nested graph.
// On every call it records whether the tool itself was reported as a resume target.
type wrapperToolForTest struct {
	// compiledGraph is the nested graph executed by InvokableRun.
	compiledGraph Runnable[string, string]
	// isResumeTargetLog gets one entry appended per InvokableRun call.
	// Each entry is the first result of tool.GetResumeContext.
	isResumeTargetLog []bool
}

// Info returns the static tool metadata (name "wrapperTool").
func (w *wrapperToolForTest) Info(ctx context.Context) (*schema.ToolInfo, error) {
	return &schema.ToolInfo{
		Name: "wrapperTool",
		Desc: "A tool that wraps a nested graph",
	}, nil
}

// InvokableRun logs whether this call is a resume target, then invokes the nested graph.
// An interrupt raised by the nested graph is re-raised via tool.CompositeInterrupt with
// the info "wrapper tool interrupt" and the nested error attached.
// Any other error is returned unchanged.
func (w *wrapperToolForTest) InvokableRun(ctx context.Context, input string, opts ...tool.Option) (string, error) {
	isResumeTarget, _, _ := tool.GetResumeContext[any](ctx)
	w.isResumeTargetLog = append(w.isResumeTargetLog, isResumeTarget)
	result, err := w.compiledGraph.Invoke(ctx, input)
	if err != nil {
		if _, ok := ExtractInterruptInfo(err); ok {
			// Wrap the nested interrupt so that the interrupt chain contains an entry for
			// this tool (the test below looks for it among the root cause's parents).
			return "", tool.CompositeInterrupt(ctx, "wrapper tool interrupt", nil, err)
		}
		return "", err
	}
	return result, nil
}

// TestToolCompositeInterruptWithNestedGraphInterrupt checks interrupt/resume propagation
// through a tool that wraps a nested graph:
// outer graph -> ToolsNode -> wrapperToolForTest -> nestedGraph -> subSubGraph -> interruptNode.
func TestToolCompositeInterruptWithNestedGraphInterrupt(t *testing.T) {
	ctx := context.Background()
	// Set by interruptNode on its resumed run; asserted after the resume invocation.
	var innerNodeIsResumeTarget bool

	// Innermost graph: interrupts on the first run, reports resume-target status afterwards.
	subSubGraph := NewGraph[string, string]()
	err := subSubGraph.AddLambdaNode("interruptNode", InvokableLambda(func(ctx context.Context, input string) (string, error) {
		wasInterrupted, _, _ := GetInterruptState[any](ctx)
		if !wasInterrupted {
			return "", Interrupt(ctx, "sub-sub graph interrupt info")
		}
		isResumeTarget, _, _ := GetResumeContext[any](ctx)
		innerNodeIsResumeTarget = isResumeTarget
		return "resumed successfully", nil
	}))
	assert.NoError(t, err)
	assert.NoError(t, subSubGraph.AddEdge(START, "interruptNode"))
	assert.NoError(t, subSubGraph.AddEdge("interruptNode", END))

	// Middle graph: just embeds subSubGraph as a graph node.
	nestedGraph := NewGraph[string, string]()
	err = nestedGraph.AddGraphNode("subSubGraph", subSubGraph)
	assert.NoError(t, err)
	assert.NoError(t, nestedGraph.AddEdge(START, "subSubGraph"))
	assert.NoError(t, nestedGraph.AddEdge("subSubGraph", END))

	compiledNestedGraph, err := nestedGraph.Compile(ctx)
	assert.NoError(t, err)

	wrapperTool := &wrapperToolForTest{compiledGraph: compiledNestedGraph.(Runnable[string, string])}

	toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{Tools: []tool.BaseTool{wrapperTool}})
	assert.NoError(t, err)

	// Outer graph: a single ToolsNode, compiled with a checkpoint store so it can be resumed.
	outerGraph := NewGraph[*schema.Message, []*schema.Message]()
	err = outerGraph.AddToolsNode("tools", toolsNode)
	assert.NoError(t, err)
	assert.NoError(t, outerGraph.AddEdge(START, "tools"))
	assert.NoError(t, outerGraph.AddEdge("tools", END))

	compiledOuterGraph, err := outerGraph.Compile(ctx, WithCheckPointStore(newInMemoryStore()))
	assert.NoError(t, err)

	checkpointID := "test-wrapper-tool-resume"
	inputMsg := &schema.Message{
		Role: schema.Assistant,
		ToolCalls: []schema.ToolCall{
			{ID: "call_1", Function: schema.FunctionCall{Name: "wrapperTool", Arguments: `"test"`}},
		},
	}

	// First run: the innermost node interrupts.
	_, err = compiledOuterGraph.Invoke(ctx, inputMsg, WithCheckPointID(checkpointID))
	assert.Error(t, err)

	info, ok := ExtractInterruptInfo(err)
	assert.True(t, ok, "should be an interrupt error")
	assert.NotNil(t, info)
	assert.NotEmpty(t, info.InterruptContexts)

	rootCause := info.InterruptContexts[0]
	assert.Equal(t, "sub-sub graph interrupt info", rootCause.Info)
	assert.True(t, rootCause.IsRootCause)

	// Walk up the parent chain looking for the composite interrupt added by the wrapper tool.
	var wrapperToolParent *InterruptCtx
	for p := rootCause.Parent; p != nil; p = p.Parent {
		if p.Info == "wrapper tool interrupt" {
			wrapperToolParent = p
			break
		}
	}
	assert.NotNil(t, wrapperToolParent, "should have parent from wrapper tool with info 'wrapper tool interrupt'")

	assert.Len(t, wrapperTool.isResumeTargetLog, 1)
	assert.False(t, wrapperTool.isResumeTargetLog[0], "first invocation: wrapper tool should not be resume target")

	// Second run: resume targeting only the root cause; ancestors should see themselves as targets.
	resumeCtx := Resume(ctx, rootCause.ID)
	_, err = compiledOuterGraph.Invoke(resumeCtx, inputMsg, WithCheckPointID(checkpointID))
	assert.NoError(t, err)

	assert.True(t, innerNodeIsResumeTarget, "inner node should be resume target")

	assert.Len(t, wrapperTool.isResumeTargetLog, 2)
	assert.True(t, wrapperTool.isResumeTargetLog[1], "second invocation: wrapper tool should be resume target because its child is targeted")
}

================================================
FILE: compose/runnable.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package compose

import (
	"context"
	"fmt"
	"reflect"

	"github.com/cloudwego/eino/internal/generic"
	"github.com/cloudwego/eino/schema"
)

// Runnable is the interface for an executable object. Graph and Chain are compiled into a Runnable.
// Runnable is the core concept of eino. It exposes all four data-flow patterns:
//   - Invoke:    single input  -> single output
//   - Stream:    single input  -> stream output
//   - Collect:   stream input  -> single output
//   - Transform: stream input  -> stream output
//
// A component only needs to implement one or more of them; the missing patterns are
// derived automatically. For example, a component that only implements Stream can still
// be called through Invoke, in which case its stream output is concatenated into a single value.
type Runnable[I, O any] interface {
	Invoke(ctx context.Context, input I, opts ...Option) (output O, err error)
	Stream(ctx context.Context, input I, opts ...Option) (output *schema.StreamReader[O], err error)
	Collect(ctx context.Context, input *schema.StreamReader[I], opts ...Option) (output O, err error)
	Transform(ctx context.Context, input *schema.StreamReader[I], opts ...Option) (output *schema.StreamReader[O], err error)
}

// invoke is the type-erased (any-typed) single-value entry point stored in composableRunnable.
type invoke func(ctx context.Context, input any, opts ...any) (output any, err error)

// transform is the type-erased stream-to-stream entry point stored in composableRunnable.
type transform func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error)

// composableRunnable is the wrapper for every executable object directly provided by the user.
// One instance corresponds to exactly one executable object.
// all information comes from executable object without any other dimensions of information. // for the graphNode, ChainBranch, StatePreHandler, StatePostHandler etc. type composableRunnable struct { i invoke t transform inputType reflect.Type outputType reflect.Type optionType reflect.Type *genericHelper isPassthrough bool meta *executorMeta // only available when in Graph node // if composableRunnable not in Graph node, this field would be nil nodeInfo *nodeInfo } func runnableLambda[I, O, TOption any](i Invoke[I, O, TOption], s Stream[I, O, TOption], c Collect[I, O, TOption], t Transform[I, O, TOption], enableCallback bool) *composableRunnable { rp := newRunnablePacker(i, s, c, t, enableCallback) return rp.toComposableRunnable() } type runnablePacker[I, O, TOption any] struct { i Invoke[I, O, TOption] s Stream[I, O, TOption] c Collect[I, O, TOption] t Transform[I, O, TOption] } func (rp *runnablePacker[I, O, TOption]) wrapRunnableCtx(ctxWrapper func(ctx context.Context, opts ...TOption) context.Context) { i, s, c, t := rp.i, rp.s, rp.c, rp.t rp.i = func(ctx context.Context, input I, opts ...TOption) (output O, err error) { ctx = ctxWrapper(ctx, opts...) return i(ctx, input, opts...) } rp.s = func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { ctx = ctxWrapper(ctx, opts...) return s(ctx, input, opts...) } rp.c = func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { ctx = ctxWrapper(ctx, opts...) return c(ctx, input, opts...) } rp.t = func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) { ctx = ctxWrapper(ctx, opts...) return t(ctx, input, opts...) 
} } func (rp *runnablePacker[I, O, TOption]) toComposableRunnable() *composableRunnable { inputType := generic.TypeOf[I]() outputType := generic.TypeOf[O]() optionType := generic.TypeOf[TOption]() c := &composableRunnable{ genericHelper: newGenericHelper[I, O](), inputType: inputType, outputType: outputType, optionType: optionType, } i := func(ctx context.Context, input any, opts ...any) (output any, err error) { in, ok := input.(I) if !ok { // When a nil is passed as an 'any' type, its original type information is lost, // becoming an untyped nil. This would cause type assertions to fail. // So if the input is nil and the target type I is an interface, we need to explicitly create a nil of type I. if input == nil && reflect.TypeOf((*I)(nil)).Elem().Kind() == reflect.Interface { var i I in = i } else { panic(newUnexpectedInputTypeErr(inputType, reflect.TypeOf(input))) } } tos, err := convertOption[TOption](opts...) if err != nil { return nil, err } return rp.Invoke(ctx, in, tos...) } t := func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { in, ok := unpackStreamReader[I](input) if !ok { panic(newUnexpectedInputTypeErr(reflect.TypeOf(in), input.getType())) } tos, err := convertOption[TOption](opts...) if err != nil { return nil, err } out, err := rp.Transform(ctx, in, tos...) if err != nil { return nil, err } return packStreamReader(out), nil } c.i = i c.t = t return c } // Invoke works like `ping => pong`. func (rp *runnablePacker[I, O, TOption]) Invoke(ctx context.Context, input I, opts ...TOption) (output O, err error) { return rp.i(ctx, input, opts...) } // Stream works like `ping => stream output`. func (rp *runnablePacker[I, O, TOption]) Stream(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { return rp.s(ctx, input, opts...) } // Collect works like `stream input => pong`. 
func (rp *runnablePacker[I, O, TOption]) Collect(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { return rp.c(ctx, input, opts...) } // Transform works like `stream input => stream output`. func (rp *runnablePacker[I, O, TOption]) Transform(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) { return rp.t(ctx, input, opts...) } func defaultImplConcatStreamReader[T any]( sr *schema.StreamReader[T]) (T, error) { c, err := concatStreamReader(sr) if err != nil { var t T return t, fmt.Errorf("concat stream reader fail: %w", err) } return c, nil } func invokeByStream[I, O, TOption any](s Stream[I, O, TOption]) Invoke[I, O, TOption] { return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { sr, err := s(ctx, input, opts...) if err != nil { return output, err } return defaultImplConcatStreamReader(sr) } } func invokeByCollect[I, O, TOption any](c Collect[I, O, TOption]) Invoke[I, O, TOption] { return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { sr := schema.StreamReaderFromArray([]I{input}) return c(ctx, sr, opts...) } } func invokeByTransform[I, O, TOption any](t Transform[I, O, TOption]) Invoke[I, O, TOption] { return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { srInput := schema.StreamReaderFromArray([]I{input}) srOutput, err := t(ctx, srInput, opts...) if err != nil { return output, err } return defaultImplConcatStreamReader(srOutput) } } func streamByTransform[I, O, TOption any](t Transform[I, O, TOption]) Stream[I, O, TOption] { return func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { srInput := schema.StreamReaderFromArray([]I{input}) return t(ctx, srInput, opts...) 
} } func streamByInvoke[I, O, TOption any](i Invoke[I, O, TOption]) Stream[I, O, TOption] { return func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { out, err := i(ctx, input, opts...) if err != nil { return nil, err } return schema.StreamReaderFromArray([]O{out}), nil } } func streamByCollect[I, O, TOption any](c Collect[I, O, TOption]) Stream[I, O, TOption] { return func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { srInput := schema.StreamReaderFromArray([]I{input}) out, err := c(ctx, srInput, opts...) if err != nil { return nil, err } return schema.StreamReaderFromArray([]O{out}), nil } } func collectByTransform[I, O, TOption any](t Transform[I, O, TOption]) Collect[I, O, TOption] { return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { srOutput, err := t(ctx, input, opts...) if err != nil { return output, err } return defaultImplConcatStreamReader(srOutput) } } func collectByInvoke[I, O, TOption any](i Invoke[I, O, TOption]) Collect[I, O, TOption] { return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { in, err := defaultImplConcatStreamReader(input) if err != nil { return output, err } return i(ctx, in, opts...) } } func collectByStream[I, O, TOption any](s Stream[I, O, TOption]) Collect[I, O, TOption] { return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { in, err := defaultImplConcatStreamReader(input) if err != nil { return output, err } srOutput, err := s(ctx, in, opts...) 
if err != nil { return output, err } return defaultImplConcatStreamReader(srOutput) } } func transformByStream[I, O, TOption any](s Stream[I, O, TOption]) Transform[I, O, TOption] { return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) { in, err := defaultImplConcatStreamReader(input) if err != nil { return output, err } return s(ctx, in, opts...) } } func transformByCollect[I, O, TOption any](c Collect[I, O, TOption]) Transform[I, O, TOption] { return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) { out, err := c(ctx, input, opts...) if err != nil { return output, err } return schema.StreamReaderFromArray([]O{out}), nil } } func transformByInvoke[I, O, TOption any](i Invoke[I, O, TOption]) Transform[I, O, TOption] { return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) { in, err := defaultImplConcatStreamReader(input) if err != nil { return output, err } out, err := i(ctx, in, opts...) 
if err != nil { return output, err } return schema.StreamReaderFromArray([]O{out}), nil } } func newRunnablePacker[I, O, TOption any](i Invoke[I, O, TOption], s Stream[I, O, TOption], c Collect[I, O, TOption], t Transform[I, O, TOption], enableCallback bool) *runnablePacker[I, O, TOption] { r := &runnablePacker[I, O, TOption]{} if enableCallback { if i != nil { i = invokeWithCallbacks(i) } if s != nil { s = streamWithCallbacks(s) } if c != nil { c = collectWithCallbacks(c) } if t != nil { t = transformWithCallbacks(t) } } if i != nil { r.i = i } else if s != nil { r.i = invokeByStream(s) } else if c != nil { r.i = invokeByCollect(c) } else { r.i = invokeByTransform(t) } if s != nil { r.s = s } else if t != nil { r.s = streamByTransform(t) } else if i != nil { r.s = streamByInvoke(i) } else { r.s = streamByCollect(c) } if c != nil { r.c = c } else if t != nil { r.c = collectByTransform(t) } else if i != nil { r.c = collectByInvoke(i) } else { r.c = collectByStream(s) } if t != nil { r.t = t } else if s != nil { r.t = transformByStream(s) } else if c != nil { r.t = transformByCollect(c) } else { r.t = transformByInvoke(i) } return r } func toGenericRunnable[I, O any](cr *composableRunnable, ctxWrapper func(ctx context.Context, opts ...Option) context.Context) ( *runnablePacker[I, O, Option], error) { i := func(ctx context.Context, input I, opts ...Option) (output O, err error) { out, err := cr.i(ctx, input, toAnyList(opts)...) if err != nil { return output, err } to, ok := out.(O) if !ok { // When a nil is passed as an 'any' type, its original type information is lost, // becoming an untyped nil. This would cause type assertions to fail. // So if the output is nil and the target type O is an interface, we need to explicitly create a nil of type O. 
if out == nil && generic.TypeOf[O]().Kind() == reflect.Interface { var o O to = o } else { panic(newUnexpectedInputTypeErr(generic.TypeOf[O](), reflect.TypeOf(out))) } } return to, nil } t := func(ctx context.Context, input *schema.StreamReader[I], opts ...Option) (output *schema.StreamReader[O], err error) { in := packStreamReader(input) out, err := cr.t(ctx, in, toAnyList(opts)...) if err != nil { return nil, err } output, ok := unpackStreamReader[O](out) if !ok { panic("impossible") } return output, nil } r := newRunnablePacker(i, nil, nil, t, false) r.wrapRunnableCtx(ctxWrapper) return r, nil } func inputKeyedComposableRunnable(key string, r *composableRunnable) *composableRunnable { wrapper := *r wrapper.genericHelper = wrapper.genericHelper.forMapInput() i := r.i wrapper.i = func(ctx context.Context, input any, opts ...any) (output any, err error) { v, ok := input.(map[string]any)[key] if !ok { return nil, fmt.Errorf("cannot find input key: %s", key) } out, err := i(ctx, v, opts...) if err != nil { return nil, err } return out, nil } t := r.t wrapper.t = func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { nInput, ok := r.inputStreamFilter(key, input) if !ok { return nil, fmt.Errorf("inputStreamFilter failed, key= %s, node name= %s, err= %w", key, r.nodeInfo.name, err) } out, err := t(ctx, nInput, opts...) if err != nil { return nil, err } return out, nil } wrapper.inputType = generic.TypeOf[map[string]any]() return &wrapper } func outputKeyedComposableRunnable(key string, r *composableRunnable) *composableRunnable { wrapper := *r wrapper.genericHelper = wrapper.genericHelper.forMapOutput() i := r.i wrapper.i = func(ctx context.Context, input any, opts ...any) (output any, err error) { out, err := i(ctx, input, opts...) 
if err != nil { return nil, err } return map[string]any{key: out}, nil } t := r.t wrapper.t = func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { out, err := t(ctx, input, opts...) if err != nil { return nil, err } return out.withKey(key), nil } wrapper.outputType = generic.TypeOf[map[string]any]() return &wrapper } // composablePassthrough special runnable that passthrough input to output func composablePassthrough() *composableRunnable { r := &composableRunnable{isPassthrough: true, nodeInfo: &nodeInfo{}} r.i = func(ctx context.Context, input any, opts ...any) (output any, err error) { return input, nil } r.t = func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { return input, nil } r.meta = &executorMeta{ component: ComponentOfPassthrough, isComponentCallbackEnabled: false, componentImplType: "Passthrough", } return r } ================================================ FILE: compose/runnable_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "errors" "fmt" "io" "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestRunnableLambda(t *testing.T) { ctx := context.Background() t.Run("invoke_to_runnable", func(t *testing.T) { rl := runnableLambda( func(ctx context.Context, input int, opts ...Option) (output string, err error) { return strconv.Itoa(input) + "+" + opts[0].options[0].(string), nil }, nil, nil, nil, false) ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { return ctx } gr, err := toGenericRunnable[int, string](rl, ctxWrapper) assert.NoError(t, err) out, err := gr.Invoke(ctx, 10, WithLambdaOption("100")) assert.NoError(t, err) assert.Equal(t, "10+100", out) sr, err := gr.Stream(ctx, 10, WithLambdaOption("100")) assert.NoError(t, err) out, err = concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "10+100", out) sri, swi := schema.Pipe[int](1) _ = swi.Send(10, nil) swi.Close() sriArr := sri.Copy(2) out, err = gr.Collect(ctx, sriArr[0], WithLambdaOption("100")) assert.NoError(t, err) assert.Equal(t, "10+100", out) sr, err = gr.Transform(ctx, sriArr[1], WithLambdaOption("100")) assert.NoError(t, err) out, err = concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "10+100", out) }) t.Run("stream_to_runnable", func(t *testing.T) { rl := runnableLambda(nil, func(ctx context.Context, input int, opts ...Option) (output *schema.StreamReader[string], err error) { sro, swo := schema.Pipe[string](3) _ = swo.Send(strconv.Itoa(input), nil) _ = swo.Send("+", nil) _ = swo.Send(opts[0].options[0].(string), nil) swo.Close() return sro, nil }, nil, nil, false) ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { return ctx } gr, err := toGenericRunnable[int, string](rl, ctxWrapper) assert.NoError(t, err) out, err := gr.Invoke(ctx, 10, WithLambdaOption("100")) assert.NoError(t, err) assert.Equal(t, "10+100", out) sr, err := gr.Stream(ctx, 10, WithLambdaOption("100")) 
assert.NoError(t, err) out, err = concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "10+100", out) sri, swi := schema.Pipe[int](1) _ = swi.Send(10, nil) swi.Close() sriArr := sri.Copy(2) out, err = gr.Collect(ctx, sriArr[0], WithLambdaOption("100")) assert.NoError(t, err) assert.Equal(t, "10+100", out) sr, err = gr.Transform(ctx, sriArr[1], WithLambdaOption("100")) assert.NoError(t, err) out, err = concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "10+100", out) }) t.Run("transform_to_runnable", func(t *testing.T) { rl := runnableLambda( nil, nil, nil, func(ctx context.Context, input *schema.StreamReader[int], opts ...Option) (output *schema.StreamReader[string], err error) { in, e := input.Recv() if errors.Is(e, io.EOF) { return nil, fmt.Errorf("unpected EOF") } input.Close() sro, swo := schema.Pipe[string](3) _ = swo.Send(strconv.Itoa(in), nil) _ = swo.Send("+", nil) _ = swo.Send(opts[0].options[0].(string), nil) swo.Close() return sro, nil }, false) ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { return ctx } gr, err := toGenericRunnable[int, string](rl, ctxWrapper) assert.NoError(t, err) out, err := gr.Invoke(ctx, 10, WithLambdaOption("100")) assert.NoError(t, err) assert.Equal(t, "10+100", out) sr, err := gr.Stream(ctx, 10, WithLambdaOption("100")) assert.NoError(t, err) out, err = concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "10+100", out) sri, swi := schema.Pipe[int](1) _ = swi.Send(10, nil) swi.Close() sriArr := sri.Copy(2) out, err = gr.Collect(ctx, sriArr[0], WithLambdaOption("100")) assert.NoError(t, err) assert.Equal(t, "10+100", out) sr, err = gr.Transform(ctx, sriArr[1], WithLambdaOption("100")) assert.NoError(t, err) out, err = concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "10+100", out) }) t.Run("collect_to_runnable", func(t *testing.T) { rl := runnableLambda(nil, nil, func(ctx context.Context, input *schema.StreamReader[int], opts ...Option) (output string, err 
error) { in, e := input.Recv() if errors.Is(e, io.EOF) { return "", fmt.Errorf("unpected EOF") } input.Close() return strconv.Itoa(in) + "+" + opts[0].options[0].(string), nil }, nil, false) ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { return ctx } gr, err := toGenericRunnable[int, string](rl, ctxWrapper) assert.NoError(t, err) out, err := gr.Invoke(ctx, 10, WithLambdaOption("100")) assert.NoError(t, err) assert.Equal(t, "10+100", out) sr, err := gr.Stream(ctx, 10, WithLambdaOption("100")) assert.NoError(t, err) out, err = concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "10+100", out) sri, swi := schema.Pipe[int](1) _ = swi.Send(10, nil) swi.Close() sriArr := sri.Copy(2) out, err = gr.Collect(ctx, sriArr[0], WithLambdaOption("100")) assert.NoError(t, err) assert.Equal(t, "10+100", out) sr, err = gr.Transform(ctx, sriArr[1], WithLambdaOption("100")) assert.NoError(t, err) out, err = concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "10+100", out) }) } ================================================ FILE: compose/state.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "fmt" "reflect" "sync" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) // GenLocalState is a function that generates the state. 
type GenLocalState[S any] func(ctx context.Context) (state S)

// stateKey is the context key under which the *internalState chain is stored
// (getState reads it with ctx.Value(stateKey{})).
type stateKey struct{}

// internalState is one level of the state chain stored in the context.
type internalState struct {
	// state is the user state value for this level; getState type-asserts it against S.
	state any
	// mu guards state: getState hands out &mu, and callers lock it around the handler call.
	mu sync.Mutex
	// parent points to the enclosing (outer) graph's state level; lookup stops at nil.
	parent *internalState
}

// StatePreHandler is called before the node is executed, with the node's input and the state.
// It runs while holding the state's mutex.
// Notice: when the node receives streaming input (e.g. the graph is run with Stream), all
// input chunks are read and merged into a single value before the handler is called.
type StatePreHandler[I, S any] func(ctx context.Context, in I, state S) (I, error)

// StatePostHandler is called after the node is executed, with the node's output and the state.
// It runs while holding the state's mutex.
// Notice: when the node produces streaming output (e.g. the graph is run with Stream), all
// output chunks are read and merged into a single value before the handler is called.
type StatePostHandler[O, S any] func(ctx context.Context, out O, state S) (O, error)

// StreamStatePreHandler is called before the node is executed. It receives the node's
// input as a stream (without merging it) and returns the stream to pass on.
type StreamStatePreHandler[I, S any] func(ctx context.Context, in *schema.StreamReader[I], state S) (*schema.StreamReader[I], error)

// StreamStatePostHandler is called after the node is executed. It receives the node's
// output as a stream (without merging it) and returns the stream to pass on.
type StreamStatePostHandler[O, S any] func(ctx context.Context, out *schema.StreamReader[O], state S) (*schema.StreamReader[O], error) func convertPreHandler[I, S any](handler StatePreHandler[I, S]) *composableRunnable { rf := func(ctx context.Context, in I, opts ...any) (I, error) { cState, pMu, err := getState[S](ctx) if err != nil { return in, err } pMu.Lock() defer pMu.Unlock() return handler(ctx, in, cState) } return runnableLambda[I, I](rf, nil, nil, nil, false) } func convertPostHandler[O, S any](handler StatePostHandler[O, S]) *composableRunnable { rf := func(ctx context.Context, out O, opts ...any) (O, error) { cState, pMu, err := getState[S](ctx) if err != nil { return out, err } pMu.Lock() defer pMu.Unlock() return handler(ctx, out, cState) } return runnableLambda[O, O](rf, nil, nil, nil, false) } func streamConvertPreHandler[I, S any](handler StreamStatePreHandler[I, S]) *composableRunnable { rf := func(ctx context.Context, in *schema.StreamReader[I], opts ...any) (*schema.StreamReader[I], error) { cState, pMu, err := getState[S](ctx) if err != nil { return in, err } pMu.Lock() defer pMu.Unlock() return handler(ctx, in, cState) } return runnableLambda[I, I](nil, nil, nil, rf, false) } func streamConvertPostHandler[O, S any](handler StreamStatePostHandler[O, S]) *composableRunnable { rf := func(ctx context.Context, out *schema.StreamReader[O], opts ...any) (*schema.StreamReader[O], error) { cState, pMu, err := getState[S](ctx) if err != nil { return out, err } pMu.Lock() defer pMu.Unlock() return handler(ctx, out, cState) } return runnableLambda[O, O](nil, nil, nil, rf, false) } // ProcessState processes the state from the context in a concurrency-safe way. // This is the recommended way to access and modify state in custom nodes. // The provided function handler will be executed with exclusive access to the state (protected by mutex). 
// // State Lookup Behavior: // - If the requested state type exists in the current graph, it will be returned // - If not found in current graph, ProcessState will search in parent graph states (for nested graphs) // - This enables nested graphs to access state from their parent graphs // - Follows lexical scoping: inner state of the same type shadows outer state // // Concurrency Safety: // - ProcessState automatically locks the mutex of the state being accessed (current or parent level) // - Each state level has its own mutex, allowing concurrent access to different levels // - The lock is held for the entire duration of the handler function // // Note: This method will report an error if the state type doesn't match or state is not found in the context chain. // // Example - Basic usage in a single graph: // // lambdaFunc := func(ctx context.Context, in string, opts ...any) (string, error) { // err := compose.ProcessState[*MyState](ctx, func(ctx context.Context, state *MyState) error { // // Safely modify state // state.Count++ // return nil // }) // if err != nil { // return "", err // } // return in, nil // } // // Example - Nested graph accessing parent state: // // // In an inner graph node // innerNode := func(ctx context.Context, input string) (string, error) { // // Access parent graph's state // err := compose.ProcessState[*OuterState](ctx, func(ctx context.Context, s *OuterState) error { // s.Counter++ // Safely modify parent state // return nil // }) // if err != nil { // return "", err // } // // // Also access inner graph's own state // err = compose.ProcessState[*InnerState](ctx, func(ctx context.Context, s *InnerState) error { // s.Data = "processed" // return nil // }) // return input, nil // } func ProcessState[S any](ctx context.Context, handler func(context.Context, S) error) error { s, pMu, err := getState[S](ctx) if err != nil { return fmt.Errorf("get state from context fail: %w", err) } pMu.Lock() defer pMu.Unlock() return handler(ctx, s) } 
func getState[S any](ctx context.Context) (S, *sync.Mutex, error) { state := ctx.Value(stateKey{}) if state == nil { var s S return s, nil, fmt.Errorf("have not set state") } interState := state.(*internalState) for interState != nil { if cState, ok := interState.state.(S); ok { return cState, &interState.mu, nil } interState = interState.parent } var s S return s, nil, fmt.Errorf("cannot find state with type: %v in states chain, "+ "current state type: %v", generic.TypeOf[S](), reflect.TypeOf(state.(*internalState).state)) } ================================================ FILE: compose/state_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "fmt" "io" "strings" "sync" "testing" "time" "unicode/utf8" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) type midStr string func TestStateGraphWithEdge(t *testing.T) { ctx := context.Background() const ( nodeOfL1 = "invokable" nodeOfL2 = "streamable" nodeOfL3 = "transformable" ) type testState struct { ms []string } gen := func(ctx context.Context) *testState { return &testState{} } sg := NewGraph[string, string](WithGenLocalState(gen)) l1 := InvokableLambda(func(ctx context.Context, in string) (out midStr, err error) { return midStr("InvokableLambda: " + in), nil }) l1StateToInput := func(ctx context.Context, in string, state *testState) (string, error) { state.ms = append(state.ms, in) return in, nil } l1StateToOutput := func(ctx context.Context, out midStr, state *testState) (midStr, error) { state.ms = append(state.ms, string(out)) return out, nil } err := sg.AddLambdaNode(nodeOfL1, l1, WithStatePreHandler(l1StateToInput), WithStatePostHandler(l1StateToOutput)) assert.NoError(t, err) l2 := StreamableLambda(func(ctx context.Context, input midStr) (output *schema.StreamReader[string], err error) { outStr := "StreamableLambda: " + string(input) sr, sw := schema.Pipe[string](utf8.RuneCountInString(outStr)) go func() { for _, field := range strings.Fields(outStr) { sw.Send(field+" ", nil) } sw.Close() }() return sr, nil }) l2StateToOutput := func(ctx context.Context, out string, state *testState) (string, error) { state.ms = append(state.ms, out) return out, nil } err = sg.AddLambdaNode(nodeOfL2, l2, WithStatePostHandler(l2StateToOutput)) assert.NoError(t, err) l3 := TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) ( output *schema.StreamReader[string], err error) { prefix := "TransformableLambda: " sr, sw := schema.Pipe[string](20) go func() { for _, field := range strings.Fields(prefix) { sw.Send(field+" ", nil) } defer input.Close() for { chunk, err := input.Recv() 
if err != nil { if err == io.EOF { break } // TODO: how to trace this kind of error in the goroutine of processing stream sw.Send(chunk, err) break } sw.Send(chunk, nil) } sw.Close() }() return sr, nil }) l3StateToOutput := func(ctx context.Context, out string, state *testState) (string, error) { state.ms = append(state.ms, out) assert.Len(t, state.ms, 4) return out, nil } err = sg.AddLambdaNode(nodeOfL3, l3, WithStatePostHandler(l3StateToOutput)) assert.NoError(t, err) err = sg.AddEdge(START, nodeOfL1) assert.NoError(t, err) err = sg.AddEdge(nodeOfL1, nodeOfL2) assert.NoError(t, err) err = sg.AddEdge(nodeOfL2, nodeOfL3) assert.NoError(t, err) err = sg.AddEdge(nodeOfL3, END) assert.NoError(t, err) run, err := sg.Compile(ctx) assert.NoError(t, err) out, err := run.Invoke(ctx, "how are you") assert.NoError(t, err) assert.Equal(t, "TransformableLambda: StreamableLambda: InvokableLambda: how are you ", out) stream, err := run.Stream(ctx, "how are you") assert.NoError(t, err) out, err = concatStreamReader(stream) assert.NoError(t, err) assert.Equal(t, "TransformableLambda: StreamableLambda: InvokableLambda: how are you ", out) sr, sw := schema.Pipe[string](1) sw.Send("how are you", nil) sw.Close() stream, err = run.Transform(ctx, sr) assert.NoError(t, err) out, err = concatStreamReader(stream) assert.NoError(t, err) assert.Equal(t, "TransformableLambda: StreamableLambda: InvokableLambda: how are you ", out) } func TestStateGraphUtils(t *testing.T) { t.Run("getState_success", func(t *testing.T) { type testStruct struct { UserID int64 } ctx := context.Background() ctx = context.WithValue(ctx, stateKey{}, &internalState{ state: &testStruct{UserID: 10}, }) var userID int64 err := ProcessState[*testStruct](ctx, func(_ context.Context, state *testStruct) error { userID = state.UserID return nil }) assert.NoError(t, err) assert.Equal(t, int64(10), userID) }) t.Run("getState_nil", func(t *testing.T) { type testStruct struct { UserID int64 } ctx := context.Background() ctx = 
context.WithValue(ctx, stateKey{}, &internalState{}) err := ProcessState[*testStruct](ctx, func(_ context.Context, state *testStruct) error { return nil }) assert.ErrorContains(t, err, "cannot find state with type: *compose.testStruct in states chain, "+ "current state type: ") }) t.Run("getState_type_error", func(t *testing.T) { type testStruct struct { UserID int64 } ctx := context.Background() ctx = context.WithValue(ctx, stateKey{}, &internalState{ state: &testStruct{UserID: 10}, }) err := ProcessState[string](ctx, func(_ context.Context, state string) error { return nil }) assert.ErrorContains(t, err, "cannot find state with type: string in states chain, "+ "current state type: *compose.testStruct") }) } func TestStateChain(t *testing.T) { ctx := context.Background() type testState struct { Field1 string Field2 string } sc := NewChain[string, string](WithGenLocalState(func(ctx context.Context) (state *testState) { return &testState{} })) r, err := sc.AppendLambda(InvokableLambda(func(ctx context.Context, input string) (output string, err error) { err = ProcessState[*testState](ctx, func(_ context.Context, state *testState) error { state.Field1 = "node1" return nil }) if err != nil { return "", err } return input, nil }), WithStatePostHandler(func(ctx context.Context, out string, state *testState) (string, error) { state.Field2 = "node2" return out, nil })). 
AppendLambda(InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input, nil }), WithStatePreHandler(func(ctx context.Context, in string, state *testState) (string, error) { return in + state.Field1 + state.Field2, nil })).Compile(ctx) if err != nil { t.Fatal(err) } result, err := r.Invoke(ctx, "start") if err != nil { t.Fatal(err) } if result != "startnode1node2" { t.Fatal("result is unexpected") } } func TestStreamState(t *testing.T) { type testState struct { Field1 string } ctx := context.Background() s := &testState{Field1: "1"} g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testState) { return s })) err := g.AddLambdaNode("1", TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) { return input, nil }), WithStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[string], state *testState) (*schema.StreamReader[string], error) { sr, sw := schema.Pipe[string](5) for i := 0; i < 5; i++ { sw.Send(state.Field1, nil) } sw.Close() return sr, nil }), WithStreamStatePostHandler(func(ctx context.Context, in *schema.StreamReader[string], state *testState) (*schema.StreamReader[string], error) { ss := in.Copy(2) for { chunk, err := ss[0].Recv() if err == io.EOF { return ss[1], nil } if err != nil { return nil, err } state.Field1 += chunk } })) if err != nil { t.Fatal(err) } err = g.AddEdge(START, "1") if err != nil { t.Fatal(err) } err = g.AddEdge("1", END) if err != nil { t.Fatal(err) } r, err := g.Compile(ctx) if err != nil { t.Fatal(err) } sr, _ := schema.Pipe[string](1) streamResult, err := r.Transform(ctx, sr) if err != nil { t.Fatal(err) } if s.Field1 != "111111" { t.Fatal("state is unexpected") } for i := 0; i < 5; i++ { chunk, err := streamResult.Recv() if err != nil { t.Fatal(err) } if chunk != "1" { t.Fatal("result is unexpected") } } _, err = streamResult.Recv() if err != io.EOF { t.Fatal("result 
is unexpected") } } // Nested Graph State Tests type NestedOuterState struct { Value string Counter int } type NestedInnerState struct { Value string } func init() { schema.RegisterName[*NestedOuterState]("NestedOuterState") schema.RegisterName[*NestedInnerState]("NestedInnerState") } func TestNestedGraphStateAccess(t *testing.T) { // Test that inner graph can access outer graph's state genOuterState := func(ctx context.Context) *NestedOuterState { return &NestedOuterState{Value: "outer", Counter: 0} } genInnerState := func(ctx context.Context) *NestedInnerState { return &NestedInnerState{Value: "inner"} } innerNode := func(ctx context.Context, input string) (string, error) { // Access both inner and outer state var outerValue string err := ProcessState(ctx, func(ctx context.Context, s *NestedOuterState) error { outerValue = s.Value return nil }) if err != nil { return "", err } var innerValue string err = ProcessState(ctx, func(ctx context.Context, s *NestedInnerState) error { innerValue = s.Value return nil }) if err != nil { return "", err } return fmt.Sprintf("%s_inner=%s_outer=%s", input, innerValue, outerValue), nil } innerGraph := NewGraph[string, string](WithGenLocalState(genInnerState)) _ = innerGraph.AddLambdaNode("inner_node", InvokableLambda(innerNode)) _ = innerGraph.AddEdge(START, "inner_node") _ = innerGraph.AddEdge("inner_node", END) outerGraph := NewGraph[string, string](WithGenLocalState(genOuterState)) _ = outerGraph.AddGraphNode("inner_graph", innerGraph) _ = outerGraph.AddEdge(START, "inner_graph") _ = outerGraph.AddEdge("inner_graph", END) r, err := outerGraph.Compile(context.Background()) assert.NoError(t, err) out, err := r.Invoke(context.Background(), "start") assert.NoError(t, err) assert.Equal(t, "start_inner=inner_outer=outer", out) } func TestNestedGraphStateShadowing(t *testing.T) { // Test that inner state shadows outer state of the same type (lexical scoping) type CommonState struct { Value string } genOuterState := func(ctx 
context.Context) *CommonState { return &CommonState{Value: "outer"} } genInnerState := func(ctx context.Context) *CommonState { return &CommonState{Value: "inner"} } innerNode := func(ctx context.Context, input string) (string, error) { var value string err := ProcessState(ctx, func(ctx context.Context, s *CommonState) error { // Should see "inner" because inner state shadows outer state value = s.Value return nil }) if err != nil { return "", err } return input + "_" + value, nil } innerGraph := NewGraph[string, string](WithGenLocalState(genInnerState)) _ = innerGraph.AddLambdaNode("inner_node", InvokableLambda(innerNode)) _ = innerGraph.AddEdge(START, "inner_node") _ = innerGraph.AddEdge("inner_node", END) outerGraph := NewGraph[string, string](WithGenLocalState(genOuterState)) _ = outerGraph.AddGraphNode("inner_graph", innerGraph) _ = outerGraph.AddEdge(START, "inner_graph") _ = outerGraph.AddEdge("inner_graph", END) r, err := outerGraph.Compile(context.Background()) assert.NoError(t, err) out, err := r.Invoke(context.Background(), "start") assert.NoError(t, err) assert.Equal(t, "start_inner", out) } func TestNestedGraphStateAfterResume(t *testing.T) { // Test that state parent linking works correctly after resume // when the outer state is restored from checkpoint (new instance) genOuterState := func(ctx context.Context) *NestedOuterState { return &NestedOuterState{Value: "outer", Counter: 0} } genInnerState := func(ctx context.Context) *NestedInnerState { return &NestedInnerState{Value: "inner"} } // Node that modifies outer state outerNode := func(ctx context.Context, input string) (string, error) { err := ProcessState(ctx, func(ctx context.Context, s *NestedOuterState) error { s.Counter = 42 return nil }) if err != nil { return "", err } return input, nil } // Inner node that reads outer state innerNode := func(ctx context.Context, input string) (string, error) { var outerCounter int var outerValue string err := ProcessState(ctx, func(ctx context.Context, s 
*NestedOuterState) error { // Should see the modified counter value from the restored state outerCounter = s.Counter outerValue = s.Value return nil }) if err != nil { return "", err } return fmt.Sprintf("%s_counter=%d_value=%s", input, outerCounter, outerValue), nil } innerGraph := NewGraph[string, string](WithGenLocalState(genInnerState)) _ = innerGraph.AddLambdaNode("inner_node", InvokableLambda(innerNode)) _ = innerGraph.AddEdge(START, "inner_node") _ = innerGraph.AddEdge("inner_node", END) outerGraph := NewGraph[string, string](WithGenLocalState(genOuterState)) _ = outerGraph.AddLambdaNode("outer_node", InvokableLambda(outerNode)) _ = outerGraph.AddGraphNode("inner_graph", innerGraph, WithGraphCompileOptions(WithInterruptBeforeNodes([]string{"inner_node"}))) _ = outerGraph.AddEdge(START, "outer_node") _ = outerGraph.AddEdge("outer_node", "inner_graph") _ = outerGraph.AddEdge("inner_graph", END) store := newInMemoryStore() r, err := outerGraph.Compile(context.Background(), WithCheckPointStore(store)) assert.NoError(t, err) // First run - should interrupt after modifying outer state _, err = r.Invoke(context.Background(), "start", WithCheckPointID("state_resume_test")) assert.Error(t, err) // Resume - outer state should be restored with Counter=42 // Inner graph should link to this restored outer state out, err := r.Invoke(context.Background(), "start", WithCheckPointID("state_resume_test")) assert.NoError(t, err) assert.Equal(t, "start_counter=42_value=outer", out) } func TestLambdaNestedGraphStateAccess(t *testing.T) { // Test that inner graph invoked from a lambda can access outer graph's state // This tests the case: outer graph -> lambda node -> inner graph (using CompositeInterrupt) genOuterState := func(ctx context.Context) *NestedOuterState { return &NestedOuterState{Value: "outer", Counter: 100} } genInnerState := func(ctx context.Context) *NestedInnerState { return &NestedInnerState{Value: "inner"} } // Inner node that accesses outer state innerNode := 
func(ctx context.Context, input string) (string, error) { var outerValue string var outerCounter int err := ProcessState(ctx, func(ctx context.Context, s *NestedOuterState) error { outerValue = s.Value outerCounter = s.Counter return nil }) if err != nil { return "", err } var innerValue string err = ProcessState(ctx, func(ctx context.Context, s *NestedInnerState) error { innerValue = s.Value return nil }) if err != nil { return "", err } return fmt.Sprintf("%s_inner=%s_outer=%s_%d", input, innerValue, outerValue, outerCounter), nil } // Build inner graph innerGraph := NewGraph[string, string](WithGenLocalState(genInnerState)) _ = innerGraph.AddLambdaNode("inner_node", InvokableLambda(innerNode)) _ = innerGraph.AddEdge(START, "inner_node") _ = innerGraph.AddEdge("inner_node", END) // Compile inner graph as a standalone runnable innerRunnable, err := innerGraph.Compile(context.Background()) assert.NoError(t, err) // Lambda that invokes the inner graph lambdaNode := InvokableLambda(func(ctx context.Context, input string) (string, error) { // Simply invoke the inner graph - state context is passed through return innerRunnable.Invoke(ctx, input) }) // Build outer graph outerGraph := NewGraph[string, string](WithGenLocalState(genOuterState)) _ = outerGraph.AddLambdaNode("lambda_with_graph", lambdaNode) _ = outerGraph.AddEdge(START, "lambda_with_graph") _ = outerGraph.AddEdge("lambda_with_graph", END) r, err := outerGraph.Compile(context.Background()) assert.NoError(t, err) out, err := r.Invoke(context.Background(), "start") assert.NoError(t, err) assert.Equal(t, "start_inner=inner_outer=outer_100", out) } func TestLambdaNestedGraphStateAfterResume(t *testing.T) { // Test that state parent linking works correctly after resume // in the lambda-nested case (outer graph -> lambda -> inner graph) genOuterState := func(ctx context.Context) *NestedOuterState { return &NestedOuterState{Value: "outer", Counter: 0} } genInnerState := func(ctx context.Context) *NestedInnerState { 
return &NestedInnerState{Value: "inner"} } // Outer node that modifies state outerNode := func(ctx context.Context, input string) (string, error) { err := ProcessState(ctx, func(ctx context.Context, s *NestedOuterState) error { s.Counter = 99 return nil }) if err != nil { return "", err } return input, nil } // Inner lambda that interrupts on first run, reads outer state on resume innerLambda := InvokableLambda(func(ctx context.Context, input string) (string, error) { wasInterrupted, _, _ := GetInterruptState[*NestedInnerState](ctx) if !wasInterrupted { // First run: interrupt return "", StatefulInterrupt(ctx, "inner interrupt", &NestedInnerState{Value: "inner"}) } // Resumed: read outer state var outerCounter int var outerValue string err := ProcessState(ctx, func(ctx context.Context, s *NestedOuterState) error { // Should see the modified counter from the restored state outerCounter = s.Counter outerValue = s.Value return nil }) if err != nil { return "", err } return fmt.Sprintf("%s_counter=%d_value=%s", input, outerCounter, outerValue), nil }) // Build inner graph innerGraph := NewGraph[string, string](WithGenLocalState(genInnerState)) _ = innerGraph.AddLambdaNode("inner_lambda", innerLambda) _ = innerGraph.AddEdge(START, "inner_lambda") _ = innerGraph.AddEdge("inner_lambda", END) // Compile inner graph as standalone runnable with checkpoint support innerRunnable, err := innerGraph.Compile(context.Background(), WithGraphName("inner"), WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // Composite lambda that invokes the inner graph and handles interrupts compositeLambda := InvokableLambda(func(ctx context.Context, input string) (string, error) { output, err := innerRunnable.Invoke(ctx, input, WithCheckPointID("inner-cp")) if err != nil { _, isInterrupt := ExtractInterruptInfo(err) if !isInterrupt { return "", err } // Wrap the interrupt using CompositeInterrupt return "", CompositeInterrupt(ctx, "composite interrupt", nil, err) } return output, 
nil }) // Build outer graph outerGraph := NewGraph[string, string](WithGenLocalState(genOuterState)) _ = outerGraph.AddLambdaNode("outer_node", InvokableLambda(outerNode)) _ = outerGraph.AddLambdaNode("composite_lambda", compositeLambda) _ = outerGraph.AddEdge(START, "outer_node") _ = outerGraph.AddEdge("outer_node", "composite_lambda") _ = outerGraph.AddEdge("composite_lambda", END) // Compile outer graph outerRunnable, err := outerGraph.Compile(context.Background(), WithGraphName("root"), WithCheckPointStore(newInMemoryStore())) assert.NoError(t, err) // First run - should interrupt after modifying outer state checkPointID := "lambda_state_resume_test" _, err = outerRunnable.Invoke(context.Background(), "start", WithCheckPointID(checkPointID)) assert.Error(t, err) interruptInfo, isInterrupt := ExtractInterruptInfo(err) assert.True(t, isInterrupt) // Resume - outer state should be restored with Counter=99 // Inner lambda should link to this restored outer state ctx := ResumeWithData(context.Background(), interruptInfo.InterruptContexts[0].ID, nil) out, err := outerRunnable.Invoke(ctx, "start", WithCheckPointID(checkPointID)) assert.NoError(t, err) // Verify the inner lambda saw the modified counter from the restored outer state assert.Contains(t, out, "counter=99") assert.Contains(t, out, "value=outer") } func TestNestedGraphStateConcurrency(t *testing.T) { // Test that concurrent access to parent and child states uses correct locks // This verifies that ProcessState properly locks the parent state's mutex when accessing it genOuterState := func(ctx context.Context) *NestedOuterState { return &NestedOuterState{Value: "outer", Counter: 0} } genInnerState := func(ctx context.Context) *NestedInnerState { return &NestedInnerState{Value: "inner"} } // Inner node that concurrently modifies both outer and inner state innerNode := func(ctx context.Context, input string) (string, error) { var wg sync.WaitGroup errors := make(chan error, 20) // Launch 10 goroutines that 
modify outer state // If locks don't work correctly, we'll see race conditions for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() err := ProcessState(ctx, func(ctx context.Context, s *NestedOuterState) error { // ProcessState should hold the parent's lock during this entire function current := s.Counter time.Sleep(time.Millisecond) // Simulate work s.Counter = current + 1 return nil }) if err != nil { errors <- err } }() } // Launch 10 goroutines that modify inner state for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() err := ProcessState(ctx, func(ctx context.Context, s *NestedInnerState) error { // This uses the inner state's own lock return nil }) if err != nil { errors <- err } }() } wg.Wait() close(errors) // Check for errors for err := range errors { return "", err } return input, nil } innerGraph := NewGraph[string, string](WithGenLocalState(genInnerState)) _ = innerGraph.AddLambdaNode("inner_node", InvokableLambda(innerNode)) _ = innerGraph.AddEdge(START, "inner_node") _ = innerGraph.AddEdge("inner_node", END) outerGraph := NewGraph[string, string](WithGenLocalState(genOuterState)) _ = outerGraph.AddGraphNode("inner_graph", innerGraph) _ = outerGraph.AddEdge(START, "inner_graph") _ = outerGraph.AddEdge("inner_graph", END) r, err := outerGraph.Compile(context.Background()) assert.NoError(t, err) _, err = r.Invoke(context.Background(), "start") assert.NoError(t, err) // Note: This test is primarily validated by running with -race flag // If locks don't work correctly, the race detector will catch it } ================================================ FILE: compose/stream_concat.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "errors" "io" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/schema" ) // RegisterStreamChunkConcatFunc registers a function to concat stream chunks. // It's required when you want to concat stream chunks of a specific type. // for example you call Invoke() but node only implements Stream(). // call at process init // not thread safe // eg. // // type testStruct struct { // field1 string // field2 int // } // compose.RegisterStreamChunkConcatFunc(func(items []testStruct) (testStruct, error) { // return testStruct{ // field1: items[1].field1, // may implement inplace logic by your scenario // field2: items[0].field2 + items[1].field2, // }, nil // }) func RegisterStreamChunkConcatFunc[T any](fn func([]T) (T, error)) { internal.RegisterStreamChunkConcatFunc(fn) } var emptyStreamConcatErr = errors.New("stream reader is empty, concat fail") func concatStreamReader[T any](sr *schema.StreamReader[T]) (T, error) { defer sr.Close() var items []T for { chunk, err := sr.Recv() if err != nil { if err == io.EOF { break } if _, ok := schema.GetSourceName(err); ok { continue } var t T return t, newStreamReadError(err) } items = append(items, chunk) } if len(items) == 0 { var t T return t, emptyStreamConcatErr } if len(items) == 1 { return items[0], nil } res, err := internal.ConcatItems(items) if err != nil { var t T return t, err } return res, nil } ================================================ FILE: compose/stream_concat_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed 
under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "errors" "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/schema" ) type tStreamConcatItemForTest struct { s string } func concatTStreamForTest(items []tStreamConcatItemForTest) (tStreamConcatItemForTest, error) { var s string for _, item := range items { s += item.s } return tStreamConcatItemForTest{s: s}, nil } func concatIntForTest(items []int) (int, error) { var i int for _, item := range items { i += item } return i, nil } type tConcatErrForTest struct{} func concatTStreamError(_ []tConcatErrForTest) (tConcatErrForTest, error) { return tConcatErrForTest{}, errors.New("test error") } func TestConcatRegistry(t *testing.T) { RegisterStreamChunkConcatFunc(concatTStreamForTest) sr, sw := schema.Pipe[tStreamConcatItemForTest](10) go func() { for i := 0; i < 10; i++ { sw.Send(tStreamConcatItemForTest{s: strconv.Itoa(i)}, nil) } sw.Close() }() lastVal, err := concatStreamReader(sr) assert.Nil(t, err) assert.Equal(t, "0123456789", lastVal.s) } func TestStringConcat(t *testing.T) { sr, sw := schema.Pipe[string](10) go func() { for i := 0; i < 10; i++ { sw.Send(strconv.Itoa(i), nil) } sw.Close() }() lastVal, err := concatStreamReader(sr) assert.Nil(t, err) assert.Equal(t, "0123456789", lastVal) } func TestMessageConcat(t *testing.T) { sr, sw := schema.Pipe[*schema.Message](10) go func() { for i := 0; i < 10; i++ { content := 
schema.UserMessage(strconv.Itoa(i)) if i%4 == 0 { content.Extra = map[string]any{ "key_1": strconv.Itoa(i), strconv.Itoa(i): strconv.Itoa(i), } } sw.Send(content, nil) } sw.Close() }() lastVal, err := concatStreamReader(sr) assert.Nil(t, err) assert.Equal(t, "0123456789", lastVal.Content) assert.Len(t, lastVal.Extra, 4) assert.Equal(t, map[string]any{ "key_1": "048", "0": "0", "4": "4", "8": "8", }, lastVal.Extra) } func TestMapConcat(t *testing.T) { RegisterStreamChunkConcatFunc(concatTStreamForTest) RegisterStreamChunkConcatFunc(concatIntForTest) t.Run("simple map", func(t *testing.T) { sr, sw := schema.Pipe[map[string]any](10) go func() { for i := 0; i < 10; i++ { sw.Send(map[string]any{ "string": strconv.Itoa(i), "custom_concat": tStreamConcatItemForTest{s: strconv.Itoa(9 - i)}, "count": i, }, nil) } sw.Close() }() lastVal, err := concatStreamReader(sr) assert.Nil(t, err) assert.Equal(t, "0123456789", lastVal["string"]) assert.Equal(t, "9876543210", lastVal["custom_concat"].(tStreamConcatItemForTest).s) assert.Equal(t, 45, lastVal["count"]) }) t.Run("complex map", func(t *testing.T) { sr, sw := schema.Pipe[map[string]any](10) go func() { for i := 0; i < 10; i++ { // 嵌套 map, 仅允许第一层做类型合并,第二层直接覆盖 sw.Send(map[string]any{ // 嵌套 map "string": strconv.Itoa(i), "deep_map": map[string]any{ "message": &schema.Message{ Content: strconv.Itoa(i), }, "custom_concat_deep": tStreamConcatItemForTest{s: strconv.Itoa(9 - i)}, "count": i, }, "custom_concat": tStreamConcatItemForTest{s: strconv.Itoa(9 - i)}, "count": i, }, nil) } sw.Close() }() lastVal, err := concatStreamReader(sr) assert.Nil(t, err) assert.Equal(t, "0123456789", lastVal["string"]) assert.Equal(t, 45, lastVal["count"]) assert.Equal(t, "0123456789", lastVal["deep_map"].(map[string]any)["message"].(*schema.Message).Content) assert.Equal(t, "9876543210", lastVal["deep_map"].(map[string]any)["custom_concat_deep"].(tStreamConcatItemForTest).s) assert.Equal(t, 45, lastVal["deep_map"].(map[string]any)["count"]) }) } func 
TestConcatError(t *testing.T) { t.Run("map type not equal", func(t *testing.T) { a := map[string]any{ "str": "string_01", "x": "string_in_a", } b := map[string]any{ "str": "string_02", "x": 123, } _, err := internal.ConcatItems([]map[string]any{a, b}) assert.NotNil(t, err) }) t.Run("merge error", func(t *testing.T) { RegisterStreamChunkConcatFunc(concatTStreamError) _, err := internal.ConcatItems([]tConcatErrForTest{{}, {}}) assert.NotNil(t, err) }) } func TestConcatSliceValue(t *testing.T) { type testStruct struct { A string } s := []testStruct{{}, {A: "123"}, {}} result, err := internal.ConcatItems(s) assert.Nil(t, err) assert.Equal(t, testStruct{A: "123"}, result) s = []testStruct{{}, {}, {}} result, err = internal.ConcatItems(s) assert.Nil(t, err) assert.Equal(t, testStruct{}, result) } ================================================ FILE: compose/stream_reader.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "reflect" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) type streamReader interface { copy(n int) []streamReader getType() reflect.Type getChunkType() reflect.Type merge([]streamReader) streamReader withKey(string) streamReader close() toAnyStreamReader() *schema.StreamReader[any] mergeWithNames([]streamReader, []string) streamReader } type streamReaderPacker[T any] struct { sr *schema.StreamReader[T] } func (srp streamReaderPacker[T]) close() { srp.sr.Close() } func (srp streamReaderPacker[T]) copy(n int) []streamReader { ret := make([]streamReader, n) srs := srp.sr.Copy(n) for i := 0; i < n; i++ { ret[i] = streamReaderPacker[T]{srs[i]} } return ret } func (srp streamReaderPacker[T]) getType() reflect.Type { return reflect.TypeOf(srp.sr) } func (srp streamReaderPacker[T]) getChunkType() reflect.Type { return generic.TypeOf[T]() } func (srp streamReaderPacker[T]) toStreamReaders(srs []streamReader) []*schema.StreamReader[T] { ret := make([]*schema.StreamReader[T], len(srs)+1) ret[0] = srp.sr for i := 1; i < len(ret); i++ { sr, ok := unpackStreamReader[T](srs[i-1]) if !ok { return nil } ret[i] = sr } return ret } func (srp streamReaderPacker[T]) merge(isrs []streamReader) streamReader { srs := srp.toStreamReaders(isrs) sr := schema.MergeStreamReaders(srs) return packStreamReader(sr) } func (srp streamReaderPacker[T]) mergeWithNames(isrs []streamReader, names []string) streamReader { srs := srp.toStreamReaders(isrs) sr := schema.InternalMergeNamedStreamReaders(srs, names) return packStreamReader(sr) } func (srp streamReaderPacker[T]) withKey(key string) streamReader { cvt := func(v T) (map[string]any, error) { return map[string]any{key: v}, nil } ret := schema.StreamReaderWithConvert[T, map[string]any](srp.sr, cvt) return packStreamReader(ret) } func (srp streamReaderPacker[T]) toAnyStreamReader() *schema.StreamReader[any] { return schema.StreamReaderWithConvert(srp.sr, func(t T) (any, error) { return t, 
nil }) } func packStreamReader[T any](sr *schema.StreamReader[T]) streamReader { return streamReaderPacker[T]{sr} } func unpackStreamReader[T any](isr streamReader) (*schema.StreamReader[T], bool) { c, ok := isr.(streamReaderPacker[T]) if ok { return c.sr, true } typ := generic.TypeOf[T]() if typ.Kind() == reflect.Interface { return schema.StreamReaderWithConvert(isr.toAnyStreamReader(), func(t any) (T, error) { return t.(T), nil }), true } return nil, false } ================================================ FILE: compose/stream_reader_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "io" "reflect" "testing" "github.com/cloudwego/eino/schema" "github.com/stretchr/testify/assert" ) func TestArrayStreamMerge(t *testing.T) { t.Run("unpack_to_equal_type", func(t *testing.T) { a1 := []int{1, 2, 3} a2 := []int{4, 5, 6} a3 := []int{7, 8, 9} s1 := schema.StreamReaderFromArray(a1) s2 := schema.StreamReaderFromArray(a2) s3 := schema.StreamReaderFromArray(a3) sp1 := streamReaderPacker[int]{sr: s1} sp2 := streamReaderPacker[int]{sr: s2} sp3 := streamReaderPacker[int]{sr: s3} sp := sp1.merge([]streamReader{sp2, sp3}) sr, ok := unpackStreamReader[int](sp) if !ok { t.Fatal("unexpected") } defer sr.Close() var result []int for { chunk, err := sr.Recv() if err == io.EOF { break } assert.Nil(t, err) result = append(result, chunk) } if !reflect.DeepEqual(result, append(append(a1, a2...), a3...)) { t.Fatalf("result: %v error", result) } }) t.Run("unpack_to_father_type", func(t *testing.T) { a1 := []*doctor{{say: "a"}, {say: "b"}, {say: "c"}} a2 := []*doctor{{say: "d"}, {say: "e"}, {say: "f"}} a3 := []*doctor{{say: "g"}, {say: "h"}, {say: "i"}} s1 := schema.StreamReaderFromArray(a1) s2 := schema.StreamReaderFromArray(a2) s3 := schema.StreamReaderFromArray(a3) sp1 := streamReaderPacker[*doctor]{sr: s1} sp2 := streamReaderPacker[*doctor]{sr: s2} sp3 := streamReaderPacker[*doctor]{sr: s3} sp := sp1.merge([]streamReader{sp2, sp3}) sr, ok := unpackStreamReader[person](sp) assert.True(t, ok) defer sr.Close() var result []person for { chunk, err := sr.Recv() if err == io.EOF { break } assert.Nil(t, err) result = append(result, chunk) } baseline := append(append(a1, a2...), a3...) 
assert.Len(t, result, len(baseline)) for idx := range result { assert.Equal(t, baseline[idx].say, result[idx].Say()) } }) } ================================================ FILE: compose/tool_node.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "errors" "fmt" "runtime/debug" "sync" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/internal/safe" "github.com/cloudwego/eino/schema" ) type toolsNodeOptions struct { ToolOptions []tool.Option ToolList []tool.BaseTool } // ToolsNodeOption is the option func type for ToolsNode. type ToolsNodeOption func(o *toolsNodeOptions) // WithToolOption adds tool options to the ToolsNode. func WithToolOption(opts ...tool.Option) ToolsNodeOption { return func(o *toolsNodeOptions) { o.ToolOptions = append(o.ToolOptions, opts...) } } // WithToolList sets the tool list for the ToolsNode. func WithToolList(tool ...tool.BaseTool) ToolsNodeOption { return func(o *toolsNodeOptions) { o.ToolList = tool } } // ToolsNode represents a node capable of executing tools within a graph. 
// The Graph Node interface is defined as follows:
//
//	Invoke(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) ([]*schema.Message, error)
//	Stream(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error)
//
// Input: An AssistantMessage containing ToolCalls (a message with any other role, or with
// no tool calls, is rejected with an error).
// Output: An array of ToolMessage where the order of elements corresponds to the order of ToolCalls in the input
type ToolsNode struct {
	// tuple holds the per-tool endpoints built by NewToolNode; a per-call WithToolList
	// builds a replacement tuple for that call only.
	tuple *toolsTuple

	unknownToolHandler   func(ctx context.Context, name, input string) (string, error)
	executeSequentially  bool
	toolArgumentsHandler func(ctx context.Context, name, input string) (string, error)

	// The middleware chains are kept so that a per-call WithToolList can be wrapped
	// with the same middlewares as the configured tools.
	toolCallMiddlewares               []InvokableToolMiddleware
	streamToolCallMiddlewares         []StreamableToolMiddleware
	enhancedToolCallMiddlewares       []EnhancedInvokableToolMiddleware
	enhancedStreamToolCallMiddlewares []EnhancedStreamableToolMiddleware
}

// ToolInput represents the input parameters for a tool call execution.
type ToolInput struct {
	// Name is the name of the tool to be executed.
	Name string
	// Arguments contains the arguments for the tool call.
	// When ToolsNodeConfig.ToolArgumentsHandler is set, this is the handler's output
	// rather than the raw arguments generated by the model.
	Arguments string
	// CallID is the unique identifier for this tool call.
	CallID string
	// CallOptions contains tool options for the execution.
	CallOptions []tool.Option
}

// ToolOutput represents the result of a non-streaming tool call execution.
type ToolOutput struct {
	// Result contains the string output from the tool execution.
	Result string
}

// StreamToolOutput represents the result of a streaming tool call execution.
type StreamToolOutput struct {
	// Result is a stream reader that provides access to the tool's streaming output.
	Result *schema.StreamReader[string]
}

// EnhancedInvokableToolOutput represents the result of a non-streaming enhanced tool call execution.
// It supports returning structured multimodal content (text, images, audio, video, files) from tools.
type EnhancedInvokableToolOutput struct {
	// Result contains the structured multimodal output from the tool execution.
	Result *schema.ToolResult
}

// EnhancedStreamableToolOutput represents the result of a streaming enhanced tool call execution.
// It provides a stream reader for accessing multimodal content progressively.
type EnhancedStreamableToolOutput struct {
	// Result is a stream reader that provides access to the tool's streaming multimodal output.
	Result *schema.StreamReader[*schema.ToolResult]
}

// InvokableToolEndpoint is the function signature for non-streaming tool calls.
type InvokableToolEndpoint func(ctx context.Context, input *ToolInput) (*ToolOutput, error)

// StreamableToolEndpoint is the function signature for streaming tool calls.
type StreamableToolEndpoint func(ctx context.Context, input *ToolInput) (*StreamToolOutput, error)

// EnhancedInvokableToolEndpoint is the function signature for non-streaming enhanced tool calls,
// whose result is structured multimodal content (*schema.ToolResult) instead of a plain string.
// The tool receives ToolInput.Arguments wrapped as schema.ToolArgument{Text: ...}.
type EnhancedInvokableToolEndpoint func(ctx context.Context, input *ToolInput) (*EnhancedInvokableToolOutput, error)

// EnhancedStreamableToolEndpoint is the function signature for streaming enhanced tool calls,
// whose result is a stream of structured multimodal chunks (*schema.ToolResult).
type EnhancedStreamableToolEndpoint func(ctx context.Context, input *ToolInput) (*EnhancedStreamableToolOutput, error)

// InvokableToolMiddleware is a function that wraps InvokableToolEndpoint to add custom processing logic.
// It can be used to intercept, modify, or enhance tool call execution for non-streaming tools.
type InvokableToolMiddleware func(InvokableToolEndpoint) InvokableToolEndpoint

// StreamableToolMiddleware is a function that wraps StreamableToolEndpoint to add custom processing logic.
// It can be used to intercept, modify, or enhance tool call execution for streaming tools.
type StreamableToolMiddleware func(StreamableToolEndpoint) StreamableToolEndpoint

// EnhancedInvokableToolMiddleware is a function that wraps EnhancedInvokableToolEndpoint to add
// custom processing logic around non-streaming enhanced (multimodal) tool calls.
type EnhancedInvokableToolMiddleware func(EnhancedInvokableToolEndpoint) EnhancedInvokableToolEndpoint

// EnhancedStreamableToolMiddleware is a function that wraps EnhancedStreamableToolEndpoint to add
// custom processing logic around streaming enhanced (multimodal) tool calls.
type EnhancedStreamableToolMiddleware func(EnhancedStreamableToolEndpoint) EnhancedStreamableToolEndpoint

// ToolMiddleware groups middleware hooks for invokable, streamable, enhanced invokable
// and enhanced streamable tool calls. Any field may be left nil.
type ToolMiddleware struct {
	// Invokable contains middleware function for non-streaming tool calls.
	// Note: This middleware only applies to tools that implement the InvokableTool interface.
	Invokable InvokableToolMiddleware

	// Streamable contains middleware function for streaming tool calls.
	// Note: This middleware only applies to tools that implement the StreamableTool interface.
	Streamable StreamableToolMiddleware

	// EnhancedInvokable contains middleware function for non-streaming enhanced tool calls.
	// Note: This middleware only applies to tools that implement the EnhancedInvokableTool interface.
	EnhancedInvokable EnhancedInvokableToolMiddleware

	// EnhancedStreamable contains middleware function for streaming enhanced tool calls.
	// Note: This middleware only applies to tools that implement the EnhancedStreamableTool interface.
	EnhancedStreamable EnhancedStreamableToolMiddleware
}

// ToolsNodeConfig is the config for ToolsNode.
type ToolsNodeConfig struct {
	// Tools specify the list of tools that can be called. Each is a BaseTool that must also
	// implement at least one of InvokableTool, StreamableTool, EnhancedInvokableTool or
	// EnhancedStreamableTool; otherwise NewToolNode returns an error.
	// A tool that implements an enhanced interface is always called through its enhanced
	// endpoint, even if it also implements InvokableTool/StreamableTool.
	// NOTE(review): tools are indexed by Info().Name; if two tools share a name, the later
	// one silently wins.
	Tools []tool.BaseTool

	// UnknownToolsHandler handles tool calls for non-existent tools when LLM hallucinates.
	// This field is optional. When not set, calling a non-existent tool will result in an error.
	// When provided, if the LLM attempts to call a tool that doesn't exist in the Tools list,
	// this handler will be invoked instead of returning an error, allowing graceful handling of hallucinated tools.
	// The handler receives the raw model arguments (ToolArgumentsHandler is not applied) and is
	// not wrapped by ToolCallMiddlewares.
	// Parameters:
	//   - ctx: The context for the tool call
	//   - name: The name of the non-existent tool
	//   - input: The tool call input generated by llm
	// Returns:
	//   - string: The response to be returned as if the tool was executed
	//   - error: Any error that occurred during handling
	UnknownToolsHandler func(ctx context.Context, name, input string) (string, error)

	// ExecuteSequentially determines whether tool calls should be executed sequentially (in order) or in parallel.
	// When set to true, tool calls will be executed one after another in the order they appear in the input message.
	// When set to false (default), tool calls will be executed in parallel.
	ExecuteSequentially bool

	// ToolArgumentsHandler allows handling of tool arguments before execution.
	// When provided, this function will be called for each tool call to process the arguments.
	// It is not called for unknown tools or for tool calls whose results were restored after
	// an interrupt. An error from it aborts the whole ToolsNode run.
	// Parameters:
	//   - ctx: The context for the tool call
	//   - name: The name of the tool being called
	//   - arguments: The original arguments string for the tool
	// Returns:
	//   - string: The processed arguments string to be used for tool execution
	//   - error: Any error that occurred during preprocessing
	ToolArgumentsHandler func(ctx context.Context, name, arguments string) (string, error)

	// ToolCallMiddlewares configures middleware for tool calls.
	// Each element can contain Invokable, Streamable, EnhancedInvokable and/or EnhancedStreamable
	// middleware; nil fields are skipped. Each kind only applies to tools implementing the
	// matching interface.
	// Within each kind, the first middleware in the list is the outermost wrapper: it sees the
	// call first and the result last.
	ToolCallMiddlewares []ToolMiddleware
}

// NewToolNode creates a new ToolsNode.
// e.g.
// // conf := &ToolsNodeConfig{ // Tools: []tool.BaseTool{invokableTool1, streamableTool2}, // } // toolsNode, err := NewToolNode(ctx, conf) func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) { var middlewares []InvokableToolMiddleware var streamMiddlewares []StreamableToolMiddleware var enhancedInvokableMiddlewares []EnhancedInvokableToolMiddleware var enhancedStreamableMiddlewares []EnhancedStreamableToolMiddleware for _, m := range conf.ToolCallMiddlewares { if m.Invokable != nil { middlewares = append(middlewares, m.Invokable) } if m.Streamable != nil { streamMiddlewares = append(streamMiddlewares, m.Streamable) } if m.EnhancedInvokable != nil { enhancedInvokableMiddlewares = append(enhancedInvokableMiddlewares, m.EnhancedInvokable) } if m.EnhancedStreamable != nil { enhancedStreamableMiddlewares = append(enhancedStreamableMiddlewares, m.EnhancedStreamable) } } tuple, err := convTools(ctx, conf.Tools, middlewares, streamMiddlewares, enhancedInvokableMiddlewares, enhancedStreamableMiddlewares) if err != nil { return nil, err } return &ToolsNode{ tuple: tuple, unknownToolHandler: conf.UnknownToolsHandler, executeSequentially: conf.ExecuteSequentially, toolArgumentsHandler: conf.ToolArgumentsHandler, toolCallMiddlewares: middlewares, streamToolCallMiddlewares: streamMiddlewares, enhancedToolCallMiddlewares: enhancedInvokableMiddlewares, enhancedStreamToolCallMiddlewares: enhancedStreamableMiddlewares, }, nil } // ToolsInterruptAndRerunExtra carries interrupt metadata for ToolsNode reruns. type ToolsInterruptAndRerunExtra struct { // ToolCalls contains all tool calls from the original assistant message. ToolCalls []schema.ToolCall // ExecutedTools maps tool call IDs to their string output for successfully executed standard tools. ExecutedTools map[string]string // ExecutedEnhancedTools maps tool call IDs to their structured multimodal output for successfully executed enhanced tools. 
ExecutedEnhancedTools map[string]*schema.ToolResult // RerunTools contains the IDs of tool calls that need to be re-executed. RerunTools []string // RerunExtraMap stores additional metadata for each tool call that needs rerun, keyed by tool call ID. RerunExtraMap map[string]any } func init() { schema.RegisterName[*ToolsInterruptAndRerunExtra]("_eino_compose_tools_interrupt_and_rerun_extra") schema.RegisterName[*toolsInterruptAndRerunState]("_eino_compose_tools_interrupt_and_rerun_state") } type toolsInterruptAndRerunState struct { Input *schema.Message ExecutedTools map[string]string ExecutedEnhancedTools map[string]*schema.ToolResult RerunTools []string } type toolsTuple struct { indexes map[string]int meta []*executorMeta endpoints []InvokableToolEndpoint streamEndpoints []StreamableToolEndpoint enhancedInvokableEndpoints []EnhancedInvokableToolEndpoint enhancedStreamableEndpoints []EnhancedStreamableToolEndpoint } func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMiddleware, sms []StreamableToolMiddleware, ems []EnhancedInvokableToolMiddleware, esms []EnhancedStreamableToolMiddleware) (*toolsTuple, error) { ret := &toolsTuple{ indexes: make(map[string]int), meta: make([]*executorMeta, len(tools)), endpoints: make([]InvokableToolEndpoint, len(tools)), streamEndpoints: make([]StreamableToolEndpoint, len(tools)), enhancedInvokableEndpoints: make([]EnhancedInvokableToolEndpoint, len(tools)), enhancedStreamableEndpoints: make([]EnhancedStreamableToolEndpoint, len(tools)), } for idx, bt := range tools { tl, err := bt.Info(ctx) if err != nil { return nil, fmt.Errorf("(NewToolNode) failed to get tool info at idx= %d: %w", idx, err) } toolName := tl.Name var ( st tool.StreamableTool it tool.InvokableTool eiTool tool.EnhancedInvokableTool esTool tool.EnhancedStreamableTool invokable InvokableToolEndpoint streamable StreamableToolEndpoint enhancedInvokable EnhancedInvokableToolEndpoint enhancedStreamable EnhancedStreamableToolEndpoint ok bool 
meta *executorMeta ) meta = parseExecutorInfoFromComponent(components.ComponentOfTool, bt) if st, ok = bt.(tool.StreamableTool); ok { streamable = wrapStreamToolCall(st, sms, !meta.isComponentCallbackEnabled) } if it, ok = bt.(tool.InvokableTool); ok { invokable = wrapToolCall(it, ms, !meta.isComponentCallbackEnabled) } if eiTool, ok = bt.(tool.EnhancedInvokableTool); ok { enhancedInvokable = wrapEnhancedInvokableToolCall(eiTool, ems, !meta.isComponentCallbackEnabled) } if esTool, ok = bt.(tool.EnhancedStreamableTool); ok { enhancedStreamable = wrapEnhancedStreamableToolCall(esTool, esms, !meta.isComponentCallbackEnabled) } if st == nil && it == nil && eiTool == nil && esTool == nil { return nil, fmt.Errorf("tool %s is not invokable, streamable, enhanced invokable or enhanced streamable", toolName) } if streamable == nil && invokable != nil { streamable = invokableToStreamable(invokable) } if invokable == nil && streamable != nil { invokable = streamableToInvokable(streamable) } if enhancedStreamable == nil && enhancedInvokable != nil { enhancedStreamable = enhancedInvokableToEnhancedStreamable(enhancedInvokable) } if enhancedInvokable == nil && enhancedStreamable != nil { enhancedInvokable = enhancedStreamableToEnhancedInvokable(enhancedStreamable) } ret.indexes[toolName] = idx ret.meta[idx] = meta ret.endpoints[idx] = invokable ret.streamEndpoints[idx] = streamable ret.enhancedInvokableEndpoints[idx] = enhancedInvokable ret.enhancedStreamableEndpoints[idx] = enhancedStreamable } return ret, nil } func wrapToolCall(it tool.InvokableTool, middlewares []InvokableToolMiddleware, needCallback bool) InvokableToolEndpoint { middleware := func(next InvokableToolEndpoint) InvokableToolEndpoint { for i := len(middlewares) - 1; i >= 0; i-- { next = middlewares[i](next) } return next } if needCallback { it = &invokableToolWithCallback{it: it} } return middleware(func(ctx context.Context, input *ToolInput) (*ToolOutput, error) { result, err := it.InvokableRun(ctx, 
input.Arguments, input.CallOptions...) if err != nil { return nil, err } return &ToolOutput{Result: result}, nil }) } func wrapStreamToolCall(st tool.StreamableTool, middlewares []StreamableToolMiddleware, needCallback bool) StreamableToolEndpoint { middleware := func(next StreamableToolEndpoint) StreamableToolEndpoint { for i := len(middlewares) - 1; i >= 0; i-- { next = middlewares[i](next) } return next } if needCallback { st = &streamableToolWithCallback{st: st} } return middleware(func(ctx context.Context, input *ToolInput) (*StreamToolOutput, error) { result, err := st.StreamableRun(ctx, input.Arguments, input.CallOptions...) if err != nil { return nil, err } return &StreamToolOutput{Result: result}, nil }) } func wrapEnhancedInvokableToolCall(eiTool tool.EnhancedInvokableTool, middlewares []EnhancedInvokableToolMiddleware, needCallback bool) EnhancedInvokableToolEndpoint { middleware := func(next EnhancedInvokableToolEndpoint) EnhancedInvokableToolEndpoint { for i := len(middlewares) - 1; i >= 0; i-- { next = middlewares[i](next) } return next } if needCallback { eiTool = &enhancedInvokableToolWithCallback{eiTool: eiTool} } return middleware(func(ctx context.Context, input *ToolInput) (*EnhancedInvokableToolOutput, error) { result, err := eiTool.InvokableRun(ctx, &schema.ToolArgument{Text: input.Arguments}, input.CallOptions...) 
if err != nil { return nil, err } return &EnhancedInvokableToolOutput{Result: result}, nil }) } func wrapEnhancedStreamableToolCall(est tool.EnhancedStreamableTool, middlewares []EnhancedStreamableToolMiddleware, needCallback bool) EnhancedStreamableToolEndpoint { middleware := func(next EnhancedStreamableToolEndpoint) EnhancedStreamableToolEndpoint { for i := len(middlewares) - 1; i >= 0; i-- { next = middlewares[i](next) } return next } if needCallback { est = &enhancedStreamableToolWithCallback{est: est} } return middleware(func(ctx context.Context, input *ToolInput) (*EnhancedStreamableToolOutput, error) { result, err := est.StreamableRun(ctx, &schema.ToolArgument{Text: input.Arguments}, input.CallOptions...) if err != nil { return nil, err } return &EnhancedStreamableToolOutput{Result: result}, nil }) } type invokableToolWithCallback struct { it tool.InvokableTool } func (i *invokableToolWithCallback) Info(ctx context.Context) (*schema.ToolInfo, error) { return i.it.Info(ctx) } func (i *invokableToolWithCallback) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { return invokeWithCallbacks(i.it.InvokableRun)(ctx, argumentsInJSON, opts...) } type streamableToolWithCallback struct { st tool.StreamableTool } func (s *streamableToolWithCallback) Info(ctx context.Context) (*schema.ToolInfo, error) { return s.st.Info(ctx) } func (s *streamableToolWithCallback) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { return streamWithCallbacks(s.st.StreamableRun)(ctx, argumentsInJSON, opts...) 
} type enhancedInvokableToolWithCallback struct { eiTool tool.EnhancedInvokableTool } func (e *enhancedInvokableToolWithCallback) Info(ctx context.Context) (*schema.ToolInfo, error) { return e.eiTool.Info(ctx) } func (e *enhancedInvokableToolWithCallback) InvokableRun(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { return invokeEnhancedWithCallbacks(e.eiTool.InvokableRun)(ctx, toolArgument, opts...) } type enhancedStreamableToolWithCallback struct { est tool.EnhancedStreamableTool } func (e *enhancedStreamableToolWithCallback) Info(ctx context.Context) (*schema.ToolInfo, error) { return e.est.Info(ctx) } func (e *enhancedStreamableToolWithCallback) StreamableRun(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { return streamEnhancedWithCallbacks(e.est.StreamableRun)(ctx, toolArgument, opts...) } func streamableToInvokable(e StreamableToolEndpoint) InvokableToolEndpoint { return func(ctx context.Context, input *ToolInput) (*ToolOutput, error) { so, err := e(ctx, input) if err != nil { return nil, err } o, err := concatStreamReader(so.Result) if err != nil { return nil, fmt.Errorf("failed to concat StreamableTool output message stream: %w", err) } return &ToolOutput{Result: o}, nil } } func invokableToStreamable(e InvokableToolEndpoint) StreamableToolEndpoint { return func(ctx context.Context, input *ToolInput) (*StreamToolOutput, error) { o, err := e(ctx, input) if err != nil { return nil, err } return &StreamToolOutput{Result: schema.StreamReaderFromArray([]string{o.Result})}, nil } } func enhancedStreamableToEnhancedInvokable(e EnhancedStreamableToolEndpoint) EnhancedInvokableToolEndpoint { return func(ctx context.Context, input *ToolInput) (*EnhancedInvokableToolOutput, error) { so, err := e(ctx, input) if err != nil { return nil, err } o, err := concatStreamReader(so.Result) if err != nil { return nil, fmt.Errorf("failed to 
concat EnhancedStreamableTool output message stream: %w", err) } return &EnhancedInvokableToolOutput{Result: o}, nil } } func enhancedInvokableToEnhancedStreamable(e EnhancedInvokableToolEndpoint) EnhancedStreamableToolEndpoint { return func(ctx context.Context, input *ToolInput) (*EnhancedStreamableToolOutput, error) { o, err := e(ctx, input) if err != nil { return nil, err } return &EnhancedStreamableToolOutput{Result: schema.StreamReaderFromArray([]*schema.ToolResult{o.Result})}, nil } } func invokeEnhancedWithCallbacks(i func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error)) func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { return runWithCallbacks(i, onStart[*schema.ToolArgument], onEnd[*schema.ToolResult], onError) } func streamEnhancedWithCallbacks(s func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error)) func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { return runWithCallbacks(s, onStart[*schema.ToolArgument], onEndWithStreamOutput[*schema.ToolResult], onError) } type toolCallTask struct { // in endpoint InvokableToolEndpoint streamEndpoint StreamableToolEndpoint enhancedInvokableEndpoint EnhancedInvokableToolEndpoint enhancedStreamableEndpoint EnhancedStreamableToolEndpoint meta *executorMeta name string arg string callID string useEnhanced bool // out executed bool output string sOutput *schema.StreamReader[string] enhancedOutput *schema.ToolResult enhancedSOutput *schema.StreamReader[*schema.ToolResult] err error } func (tn *ToolsNode) genToolCallTasks(ctx context.Context, tuple *toolsTuple, input *schema.Message, executedTools map[string]string, executedEnhancedTools map[string]*schema.ToolResult, isStream bool) ([]toolCallTask, error) { if input.Role != schema.Assistant { return 
nil, fmt.Errorf("expected message role is Assistant, got %s", input.Role) } n := len(input.ToolCalls) if n == 0 { return nil, errors.New("no tool call found in input message") } toolCallTasks := make([]toolCallTask, n) for i := 0; i < n; i++ { toolCall := input.ToolCalls[i] if enhancedResult, executed := executedEnhancedTools[toolCall.ID]; executed { toolCallTasks[i].name = toolCall.Function.Name toolCallTasks[i].arg = toolCall.Function.Arguments toolCallTasks[i].callID = toolCall.ID toolCallTasks[i].executed = true toolCallTasks[i].useEnhanced = true if isStream { toolCallTasks[i].enhancedSOutput = schema.StreamReaderFromArray([]*schema.ToolResult{enhancedResult}) } else { toolCallTasks[i].enhancedOutput = enhancedResult } continue } if result, executed := executedTools[toolCall.ID]; executed { toolCallTasks[i].name = toolCall.Function.Name toolCallTasks[i].arg = toolCall.Function.Arguments toolCallTasks[i].callID = toolCall.ID toolCallTasks[i].executed = true toolCallTasks[i].useEnhanced = false if isStream { toolCallTasks[i].sOutput = schema.StreamReaderFromArray([]string{result}) } else { toolCallTasks[i].output = result } continue } index, ok := tuple.indexes[toolCall.Function.Name] if !ok { if tn.unknownToolHandler == nil { return nil, fmt.Errorf("tool %s not found in toolsNode indexes", toolCall.Function.Name) } toolCallTasks[i] = newUnknownToolTask(toolCall.Function.Name, toolCall.Function.Arguments, toolCall.ID, tn.unknownToolHandler) } else { toolCallTasks[i].meta = tuple.meta[index] toolCallTasks[i].name = toolCall.Function.Name toolCallTasks[i].callID = toolCall.ID if tuple.enhancedInvokableEndpoints[index] != nil && tuple.enhancedStreamableEndpoints[index] != nil { toolCallTasks[i].enhancedInvokableEndpoint = tuple.enhancedInvokableEndpoints[index] toolCallTasks[i].enhancedStreamableEndpoint = tuple.enhancedStreamableEndpoints[index] toolCallTasks[i].useEnhanced = true } else { toolCallTasks[i].endpoint = tuple.endpoints[index] 
toolCallTasks[i].streamEndpoint = tuple.streamEndpoints[index] toolCallTasks[i].useEnhanced = false } if tn.toolArgumentsHandler != nil { arg, err := tn.toolArgumentsHandler(ctx, toolCall.Function.Name, toolCall.Function.Arguments) if err != nil { return nil, fmt.Errorf("failed to executed tool[name:%s arguments:%s] arguments handler: %w", toolCall.Function.Name, toolCall.Function.Arguments, err) } toolCallTasks[i].arg = arg } else { toolCallTasks[i].arg = toolCall.Function.Arguments } } } return toolCallTasks, nil } func newUnknownToolTask(name, arg, callID string, unknownToolHandler func(ctx context.Context, name, input string) (string, error)) toolCallTask { endpoint := func(ctx context.Context, input *ToolInput) (*ToolOutput, error) { result, err := unknownToolHandler(ctx, input.Name, input.Arguments) if err != nil { return nil, err } return &ToolOutput{ Result: result, }, nil } return toolCallTask{ endpoint: endpoint, streamEndpoint: invokableToStreamable(endpoint), meta: &executorMeta{ component: components.ComponentOfTool, isComponentCallbackEnabled: false, componentImplType: "UnknownTool", }, name: name, arg: arg, callID: callID, } } func runToolCallTaskByInvoke(ctx context.Context, task *toolCallTask, opts ...tool.Option) { if task.executed { return } ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{ Name: task.name, Type: task.meta.componentImplType, Component: task.meta.component, }) ctx = setToolCallInfo(ctx, &toolCallInfo{toolCallID: task.callID}) ctx = appendToolAddressSegment(ctx, task.name, task.callID) if task.useEnhanced { enhancedOutput, err := task.enhancedInvokableEndpoint(ctx, &ToolInput{ Name: task.name, Arguments: task.arg, CallID: task.callID, CallOptions: opts, }) if err != nil { task.err = err } else { task.enhancedOutput = enhancedOutput.Result task.executed = true } } else { output, err := task.endpoint(ctx, &ToolInput{ Name: task.name, Arguments: task.arg, CallID: task.callID, CallOptions: opts, }) if err != nil { task.err = err } 
else { task.output = output.Result task.executed = true } } } func runToolCallTaskByStream(ctx context.Context, task *toolCallTask, opts ...tool.Option) { ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{ Name: task.name, Type: task.meta.componentImplType, Component: task.meta.component, }) ctx = setToolCallInfo(ctx, &toolCallInfo{toolCallID: task.callID}) ctx = appendToolAddressSegment(ctx, task.name, task.callID) if task.useEnhanced { enhancedOutput, err := task.enhancedStreamableEndpoint(ctx, &ToolInput{ Name: task.name, Arguments: task.arg, CallID: task.callID, CallOptions: opts, }) if err != nil { task.err = err } else { task.enhancedSOutput = enhancedOutput.Result task.executed = true } } else { output, err := task.streamEndpoint(ctx, &ToolInput{ Name: task.name, Arguments: task.arg, CallID: task.callID, CallOptions: opts, }) if err != nil { task.err = err } else { task.sOutput = output.Result task.executed = true } } } func sequentialRunToolCall(ctx context.Context, run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), tasks []toolCallTask, opts ...tool.Option) { for i := range tasks { if tasks[i].executed { continue } run(ctx, &tasks[i], opts...) } } func parallelRunToolCall(ctx context.Context, run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), tasks []toolCallTask, opts ...tool.Option) { if len(tasks) == 1 { run(ctx, &tasks[0], opts...) return } var wg sync.WaitGroup for i := 1; i < len(tasks); i++ { if tasks[i].executed { continue } wg.Add(1) go func(ctx_ context.Context, t *toolCallTask, opts ...tool.Option) { defer wg.Done() defer func() { panicErr := recover() if panicErr != nil { t.err = safe.NewPanicErr(panicErr, debug.Stack()) } }() run(ctx_, t, opts...) }(ctx, &tasks[i], opts...) } if !tasks[0].executed { run(ctx, &tasks[0], opts...) } wg.Wait() } // Invoke calls the tools and collects the results of invokable tools. // it's parallel if there are multiple tool calls in the input message. 
func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) ([]*schema.Message, error) { opt := getToolsNodeOptions(opts...) tuple := tn.tuple if opt.ToolList != nil { var err error tuple, err = convTools(ctx, opt.ToolList, tn.toolCallMiddlewares, tn.streamToolCallMiddlewares, tn.enhancedToolCallMiddlewares, tn.enhancedStreamToolCallMiddlewares) if err != nil { return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) } } var executedTools map[string]string var executedEnhancedTools map[string]*schema.ToolResult if wasInterrupted, hasState, tnState := GetInterruptState[*toolsInterruptAndRerunState](ctx); wasInterrupted && hasState { input = tnState.Input if tnState.ExecutedTools != nil { executedTools = tnState.ExecutedTools } if tnState.ExecutedEnhancedTools != nil { executedEnhancedTools = tnState.ExecutedEnhancedTools } } tasks, err := tn.genToolCallTasks(ctx, tuple, input, executedTools, executedEnhancedTools, false) if err != nil { return nil, err } if tn.executeSequentially { sequentialRunToolCall(ctx, runToolCallTaskByInvoke, tasks, opt.ToolOptions...) } else { parallelRunToolCall(ctx, runToolCallTaskByInvoke, tasks, opt.ToolOptions...) 
} n := len(tasks) output := make([]*schema.Message, n) rerunExtra := &ToolsInterruptAndRerunExtra{ ToolCalls: input.ToolCalls, ExecutedTools: make(map[string]string), ExecutedEnhancedTools: make(map[string]*schema.ToolResult), RerunExtraMap: make(map[string]any), } rerunState := &toolsInterruptAndRerunState{ Input: input, ExecutedTools: make(map[string]string), ExecutedEnhancedTools: make(map[string]*schema.ToolResult), } var errs []error for i := 0; i < n; i++ { if tasks[i].err != nil { info, ok := IsInterruptRerunError(tasks[i].err) if !ok { return nil, fmt.Errorf("failed to invoke tool[name:%s id:%s]: %w", tasks[i].name, tasks[i].callID, tasks[i].err) } rerunExtra.RerunTools = append(rerunExtra.RerunTools, tasks[i].callID) rerunState.RerunTools = append(rerunState.RerunTools, tasks[i].callID) if info != nil { rerunExtra.RerunExtraMap[tasks[i].callID] = info } iErr := WrapInterruptAndRerunIfNeeded(ctx, AddressSegment{ID: tasks[i].callID, Type: AddressSegmentTool}, tasks[i].err) errs = append(errs, iErr) continue } if tasks[i].executed { if tasks[i].useEnhanced { rerunExtra.ExecutedEnhancedTools[tasks[i].callID] = tasks[i].enhancedOutput rerunState.ExecutedEnhancedTools[tasks[i].callID] = tasks[i].enhancedOutput } else { rerunExtra.ExecutedTools[tasks[i].callID] = tasks[i].output rerunState.ExecutedTools[tasks[i].callID] = tasks[i].output } } if len(errs) == 0 { if tasks[i].useEnhanced { output[i] = schema.ToolMessage("", tasks[i].callID, schema.WithToolName(tasks[i].name)) output[i].UserInputMultiContent, err = tasks[i].enhancedOutput.ToMessageInputParts() if err != nil { return nil, err } } else { output[i] = schema.ToolMessage(tasks[i].output, tasks[i].callID, schema.WithToolName(tasks[i].name)) } } } if len(errs) > 0 { return nil, CompositeInterrupt(ctx, rerunExtra, rerunState, errs...) } return output, nil } // Stream calls the tools and collects the results of stream readers. // it's parallel if there are multiple tool calls in the input message. 
func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error) { opt := getToolsNodeOptions(opts...) tuple := tn.tuple if opt.ToolList != nil { var err error tuple, err = convTools(ctx, opt.ToolList, tn.toolCallMiddlewares, tn.streamToolCallMiddlewares, tn.enhancedToolCallMiddlewares, tn.enhancedStreamToolCallMiddlewares) if err != nil { return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) } } var executedTools map[string]string var executedEnhancedTools map[string]*schema.ToolResult if wasInterrupted, hasState, tnState := GetInterruptState[*toolsInterruptAndRerunState](ctx); wasInterrupted && hasState { input = tnState.Input if tnState.ExecutedTools != nil { executedTools = tnState.ExecutedTools } if tnState.ExecutedEnhancedTools != nil { executedEnhancedTools = tnState.ExecutedEnhancedTools } } tasks, err := tn.genToolCallTasks(ctx, tuple, input, executedTools, executedEnhancedTools, true) if err != nil { return nil, err } if tn.executeSequentially { sequentialRunToolCall(ctx, runToolCallTaskByStream, tasks, opt.ToolOptions...) } else { parallelRunToolCall(ctx, runToolCallTaskByStream, tasks, opt.ToolOptions...) 
} n := len(tasks) rerunExtra := &ToolsInterruptAndRerunExtra{ ToolCalls: input.ToolCalls, ExecutedTools: make(map[string]string), ExecutedEnhancedTools: make(map[string]*schema.ToolResult), RerunExtraMap: make(map[string]any), } rerunState := &toolsInterruptAndRerunState{ Input: input, ExecutedTools: make(map[string]string), ExecutedEnhancedTools: make(map[string]*schema.ToolResult), } var errs []error // check rerun for i := 0; i < n; i++ { if tasks[i].err != nil { info, ok := IsInterruptRerunError(tasks[i].err) if !ok { return nil, fmt.Errorf("failed to stream tool call %s: %w", tasks[i].callID, tasks[i].err) } rerunExtra.RerunTools = append(rerunExtra.RerunTools, tasks[i].callID) rerunState.RerunTools = append(rerunState.RerunTools, tasks[i].callID) if info != nil { rerunExtra.RerunExtraMap[tasks[i].callID] = info } iErr := WrapInterruptAndRerunIfNeeded(ctx, AddressSegment{ID: tasks[i].callID, Type: AddressSegmentTool}, tasks[i].err) errs = append(errs, iErr) continue } } if len(errs) > 0 { // concat and save tool output for _, t := range tasks { if t.executed { if t.useEnhanced { eo, err_ := concatStreamReader(t.enhancedSOutput) if err_ != nil { return nil, fmt.Errorf("failed to concat enhanced tool[name:%s id:%s]'s stream output: %w", t.name, t.callID, err_) } rerunExtra.ExecutedEnhancedTools[t.callID] = eo rerunState.ExecutedEnhancedTools[t.callID] = eo } else { o, err_ := concatStreamReader(t.sOutput) if err_ != nil { return nil, fmt.Errorf("failed to concat tool[name:%s id:%s]'s stream output: %w", t.name, t.callID, err_) } rerunExtra.ExecutedTools[t.callID] = o rerunState.ExecutedTools[t.callID] = o } } } return nil, CompositeInterrupt(ctx, rerunExtra, rerunState, errs...) 
} // common return sOutput := make([]*schema.StreamReader[[]*schema.Message], n) for i := 0; i < n; i++ { index := i callID := tasks[i].callID callName := tasks[i].name if tasks[i].useEnhanced { cvt := func(tr *schema.ToolResult) ([]*schema.Message, error) { ret := make([]*schema.Message, n) ret[index] = schema.ToolMessage("", callID, schema.WithToolName(callName)) ret[index].UserInputMultiContent, err = tr.ToMessageInputParts() if err != nil { return nil, err } return ret, nil } sOutput[i] = schema.StreamReaderWithConvert(tasks[i].enhancedSOutput, cvt) } else { cvt := func(s string) ([]*schema.Message, error) { ret := make([]*schema.Message, n) ret[index] = schema.ToolMessage(s, callID, schema.WithToolName(callName)) return ret, nil } sOutput[i] = schema.StreamReaderWithConvert(tasks[i].sOutput, cvt) } } return schema.MergeStreamReaders(sOutput), nil } // GetType returns the component type string for the Tools node. func (tn *ToolsNode) GetType() string { return "" } func getToolsNodeOptions(opts ...ToolsNodeOption) *toolsNodeOptions { o := &toolsNodeOptions{ ToolOptions: make([]tool.Option, 0), } for _, opt := range opts { opt(o) } return o } type toolCallInfoKey struct{} type toolCallInfo struct { toolCallID string } func setToolCallInfo(ctx context.Context, toolCallInfo *toolCallInfo) context.Context { return context.WithValue(ctx, toolCallInfoKey{}, toolCallInfo) } // GetToolCallID gets the current tool call id from the context. func GetToolCallID(ctx context.Context) string { v := ctx.Value(toolCallInfoKey{}) if v == nil { return "" } info, ok := v.(*toolCallInfo) if !ok { return "" } return info.toolCallID } ================================================ FILE: compose/tool_node_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "fmt" "io" "strconv" "strings" "testing" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) const ( toolNameOfUserCompany = "user_company" toolIDOfUserCompany = "call_TRZhlagwBS0LpWbWPeZOvIXc" toolNameOfUserSalary = "user_salary" toolIDOfUserSalary = "call_AqfoRW6fuF98k0o7696k2nzm" ) func TestToolsNode(t *testing.T) { var err error ctx := context.Background() userCompanyToolInfo := &schema.ToolInfo{ Name: toolNameOfUserCompany, Desc: "Query user's company and position information based on user's name and email", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Type: "string", Desc: "User's name", }, "email": { Type: "string", Desc: "User's email", }, }), } userSalaryToolInfo := &schema.ToolInfo{ Name: toolNameOfUserSalary, Desc: "Query user's salary information based on user's name and email", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Type: "string", Desc: "User's name", }, "email": { Type: "string", Desc: "User's email", }, }), } t.Run("success", func(t *testing.T) { const ( nodeOfTools = "tools" nodeOfModel = "model" ) g := NewGraph[[]*schema.Message, []*schema.Message]() err = g.AddChatModelNode(nodeOfModel, &mockIntentChatModel{}) assert.NoError(t, err) ui := newTool(userCompanyToolInfo, queryUserCompany) us := 
newStreamableTool(userSalaryToolInfo, queryUserSalary) toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{ Tools: []tool.BaseTool{ui, us}, }) assert.NoError(t, err) err = g.AddToolsNode(nodeOfTools, toolsNode) assert.NoError(t, err) err = g.AddEdge(START, nodeOfModel) assert.NoError(t, err) err = g.AddEdge(nodeOfModel, nodeOfTools) assert.NoError(t, err) err = g.AddEdge(nodeOfTools, END) assert.NoError(t, err) r, err := g.Compile(ctx) assert.NoError(t, err) out, err := r.Invoke(ctx, []*schema.Message{}) assert.NoError(t, err) msg := findMsgByToolCallID(out, toolIDOfUserCompany) assert.Equal(t, toolIDOfUserCompany, msg.ToolCallID) assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","gender":"male","company":"bytedance","position":"CEO"}`, msg.Content) msg = findMsgByToolCallID(out, toolIDOfUserSalary) assert.Equal(t, toolIDOfUserSalary, msg.ToolCallID) assert.Contains(t, msg.Content, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":5000}{"user_id":"zhangsan-zhangsan@bytedance.com","salary":3000}{"user_id":"zhangsan-zhangsan@bytedance.com","salary":2000}`) // 测试流式调用 reader, err := r.Stream(ctx, []*schema.Message{}) assert.NoError(t, err) loops := 0 userSalaryTimes := 0 defer reader.Close() var arrMsgs [][]*schema.Message for ; loops < 10; loops++ { msgs, err := reader.Recv() if err == io.EOF { break } arrMsgs = append(arrMsgs, msgs) assert.NoError(t, err) assert.Len(t, msgs, 2) if msg := findMsgByToolCallID(out, toolIDOfUserCompany); msg != nil { assert.Equal(t, schema.Tool, msg.Role) assert.Equal(t, toolIDOfUserCompany, msg.ToolCallID) assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","gender":"male","company":"bytedance","position":"CEO"}`, msg.Content) } else if msg := findMsgByToolCallID(out, toolIDOfUserSalary); msg != nil { assert.Equal(t, schema.Tool, msg.Role) assert.Equal(t, toolIDOfUserSalary, msg.ToolCallID) switch userSalaryTimes { case 0: assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":5000}`, 
msg.Content) case 1: assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":3000}`, msg.Content) case 2: assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":2000}`, msg.Content) } userSalaryTimes++ } else { assert.Fail(t, "unexpected tool name") } } assert.Equal(t, 4, loops) msgs, err_ := schema.ConcatMessageArray(arrMsgs) assert.NoError(t, err_) msg = findMsgByToolCallID(msgs, toolIDOfUserCompany) msg = findMsgByToolCallID(msgs, toolIDOfUserSalary) sr, sw := schema.Pipe[[]*schema.Message](2) sw.Send([]*schema.Message{ { Role: schema.User, Content: `hi, how are you`, }, }, nil) sw.Send([]*schema.Message{ { Role: schema.User, Content: `i'm fine'`, }, }, nil) sw.Close() reader, err = r.Transform(ctx, sr) assert.NoError(t, err) defer reader.Close() loops = 0 userSalaryTimes = 0 for ; loops < 10; loops++ { msgs, err := reader.Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Len(t, msgs, 2) if msg := findMsgByToolCallID(out, toolIDOfUserCompany); msg != nil { assert.Equal(t, schema.Tool, msg.Role) assert.Equal(t, toolIDOfUserCompany, msg.ToolCallID) assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","gender":"male","company":"bytedance","position":"CEO"}`, msg.Content) } else if msg := findMsgByToolCallID(out, toolIDOfUserSalary); msg != nil { assert.Equal(t, schema.Tool, msg.Role) assert.Equal(t, toolIDOfUserSalary, msg.ToolCallID) switch userSalaryTimes { case 0: assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":5000}`, msg.Content) case 1: assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":3000}`, msg.Content) case 2: assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":2000}`, msg.Content) } userSalaryTimes++ } else { assert.Fail(t, "unexpected tool name") } } assert.Equal(t, 4, loops) }) t.Run("order_consistency", func(t *testing.T) { // Create a ToolsNode with multiple tools ui := newTool(userCompanyToolInfo, queryUserCompany) us := 
newTool(userSalaryToolInfo, queryUserSalary) toolsNode, err_ := NewToolNode(context.Background(), &ToolsNodeConfig{ Tools: []tool.BaseTool{ui, us}, }) assert.NoError(t, err_) // Create an input message with multiple tool calls in a specific order input := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { ID: toolIDOfUserSalary, Function: schema.FunctionCall{ Name: toolNameOfUserSalary, Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, }, }, { ID: toolIDOfUserCompany, Function: schema.FunctionCall{ Name: toolNameOfUserCompany, Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, }, }, }, } // Invoke the ToolsNode output, err_ := toolsNode.Invoke(context.Background(), input) assert.NoError(t, err_) // Verify the order of output messages matches the order of input tool calls assert.Equal(t, 2, len(output)) assert.Equal(t, toolIDOfUserSalary, output[0].ToolCallID) assert.Equal(t, toolIDOfUserCompany, output[1].ToolCallID) // Test with Stream method as well streamer, err_ := toolsNode.Stream(context.Background(), input) assert.NoError(t, err_) defer streamer.Close() // Collect all stream outputs var streamOutputs [][]*schema.Message for { chunk, err__ := streamer.Recv() if err__ == io.EOF { break } assert.NoError(t, err__) streamOutputs = append(streamOutputs, chunk) } // Verify each chunk maintains the correct order for _, chunk := range streamOutputs { if chunk[0] != nil { assert.Equal(t, toolIDOfUserSalary, chunk[0].ToolCallID) } if chunk[1] != nil { assert.Equal(t, toolIDOfUserCompany, chunk[1].ToolCallID) } } // Concatenate all stream outputs and verify final result concatenated, err_ := schema.ConcatMessageArray(streamOutputs) assert.NoError(t, err_) assert.Equal(t, 2, len(concatenated)) assert.Equal(t, toolIDOfUserSalary, concatenated[0].ToolCallID) assert.Equal(t, toolIDOfUserCompany, concatenated[1].ToolCallID) }) } type userCompanyRequest struct { Name string `json:"name"` Email string `json:"email"` 
} type userCompanyResponse struct { UserID string `json:"user_id"` Gender string `json:"gender"` Company string `json:"company"` Position string `json:"position"` } func queryUserCompany(ctx context.Context, req *userCompanyRequest) (resp *userCompanyResponse, err error) { callID := GetToolCallID(ctx) if callID != toolIDOfUserCompany { return nil, fmt.Errorf("invalid tool call id= %s", callID) } return &userCompanyResponse{ UserID: fmt.Sprintf("%v-%v", req.Name, req.Email), Gender: "male", Company: "bytedance", Position: "CEO", }, nil } type userSalaryRequest struct { Name string `json:"name"` Email string `json:"email"` } type userSalaryResponse struct { UserID string `json:"user_id"` Salary int `json:"salary"` } func queryUserSalary(ctx context.Context, req *userSalaryRequest) (resp *schema.StreamReader[*userSalaryResponse], err error) { callID := GetToolCallID(ctx) if callID != toolIDOfUserSalary { return nil, fmt.Errorf("invalid tool call id= %s", callID) } sr, sw := schema.Pipe[*userSalaryResponse](10) sw.Send(&userSalaryResponse{ UserID: fmt.Sprintf("%v-%v", req.Name, req.Email), Salary: 5000, }, nil) sw.Send(&userSalaryResponse{ UserID: fmt.Sprintf("%v-%v", req.Name, req.Email), Salary: 3000, }, nil) sw.Send(&userSalaryResponse{ UserID: fmt.Sprintf("%v-%v", req.Name, req.Email), Salary: 2000, }, nil) sw.Close() return sr, nil } type mockIntentChatModel struct{} func (m *mockIntentChatModel) BindTools(tools []*schema.ToolInfo) error { return nil } func (m *mockIntentChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { return &schema.Message{ Role: schema.Assistant, Content: "", ToolCalls: []schema.ToolCall{ { ID: toolIDOfUserCompany, Function: schema.FunctionCall{ Name: toolNameOfUserCompany, Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, }, }, { ID: toolIDOfUserSalary, Function: schema.FunctionCall{ Name: toolNameOfUserSalary, Arguments: `{"name": "zhangsan", "email": 
"zhangsan@bytedance.com"}`, }, }, }, }, nil } func (m *mockIntentChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { sr, sw := schema.Pipe[*schema.Message](2) sw.Send(&schema.Message{ Role: schema.Assistant, Content: "", ToolCalls: []schema.ToolCall{ { ID: toolIDOfUserCompany, Function: schema.FunctionCall{ Name: toolNameOfUserCompany, Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, }, }, }, }, nil) sw.Send(&schema.Message{ Role: schema.Assistant, Content: "", ToolCalls: []schema.ToolCall{ { ID: toolIDOfUserSalary, Function: schema.FunctionCall{ Name: toolNameOfUserSalary, Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, }, }, }, }, nil) sw.Close() return sr, nil } func TestToolsNodeOptions(t *testing.T) { ctx := context.Background() t.Run("tool_option", func(t *testing.T) { g := NewGraph[*schema.Message, []*schema.Message]() mt := &mockTool{} tn, err := NewToolNode(ctx, &ToolsNodeConfig{ Tools: []tool.BaseTool{mt}, }) assert.NoError(t, err) err = g.AddToolsNode("tools", tn) assert.NoError(t, err) err = g.AddEdge(START, "tools") assert.NoError(t, err) err = g.AddEdge("tools", END) assert.NoError(t, err) r, err := g.Compile(ctx) assert.NoError(t, err) out, err := r.Invoke(ctx, &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { ID: toolIDOfUserCompany, Function: schema.FunctionCall{ Name: "mock_tool", Arguments: `{"name": "jack"}`, }, }, }, }, WithToolsNodeOption(WithToolOption(WithAge(10)))) assert.NoError(t, err) assert.Len(t, out, 1) assert.JSONEq(t, `{"echo": "jack: 10"}`, out[0].Content) outMessages := make([][]*schema.Message, 0) outStream, err := r.Stream(ctx, &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { ID: toolIDOfUserCompany, Function: schema.FunctionCall{ Name: "mock_tool", Arguments: `{"name": "jack"}`, }, }, }, }, WithToolsNodeOption(WithToolOption(WithAge(10)))) assert.NoError(t, 
err) for { msgs, err := outStream.Recv() if err == io.EOF { break } assert.NoError(t, err) outMessages = append(outMessages, msgs) } outStream.Close() msgs, err := internal.ConcatItems(outMessages) assert.NoError(t, err) assert.Len(t, msgs, 1) assert.JSONEq(t, `{"echo":"jack: 10"}`, msgs[0].Content) }) t.Run("tool_list", func(t *testing.T) { g := NewGraph[*schema.Message, []*schema.Message]() mt := &mockTool{} tn, err := NewToolNode(ctx, &ToolsNodeConfig{ Tools: []tool.BaseTool{}, }) assert.NoError(t, err) err = g.AddToolsNode("tools", tn) assert.NoError(t, err) err = g.AddEdge(START, "tools") assert.NoError(t, err) err = g.AddEdge("tools", END) assert.NoError(t, err) r, err := g.Compile(ctx) assert.NoError(t, err) out, err := r.Invoke(ctx, &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { ID: toolIDOfUserCompany, Function: schema.FunctionCall{ Name: "mock_tool", Arguments: `{"name": "jack"}`, }, }, }, }, WithToolsNodeOption(WithToolList(mt), WithToolOption(WithAge(10)))) assert.NoError(t, err) assert.Len(t, out, 1) assert.JSONEq(t, `{"echo": "jack: 10"}`, out[0].Content) outMessages := make([][]*schema.Message, 0) outStream, err := r.Stream(ctx, &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { ID: toolIDOfUserCompany, Function: schema.FunctionCall{ Name: "mock_tool", Arguments: `{"name": "jack"}`, }, }, }, }, WithToolsNodeOption(WithToolList(mt), WithToolOption(WithAge(10)))) assert.NoError(t, err) for { msgs, err := outStream.Recv() if err == io.EOF { break } assert.NoError(t, err) outMessages = append(outMessages, msgs) } outStream.Close() msgs, err := internal.ConcatItems(outMessages) assert.NoError(t, err) assert.Len(t, msgs, 1) assert.JSONEq(t, `{"echo":"jack: 10"}`, msgs[0].Content) }) } func findMsgByToolCallID(msgs []*schema.Message, toolCallID string) *schema.Message { for _, msg := range msgs { if msg.ToolCallID == toolCallID { return msg } } return nil } type mockToolOptions struct { Age int } func 
WithAge(age int) tool.Option { return tool.WrapImplSpecificOptFn(func(o *mockToolOptions) { o.Age = age }) } type mockToolRequest struct { Name string `json:"name"` } type mockToolResponse struct { Echo string `json:"echo"` } type mockTool struct{} func (m *mockTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "mock_tool", Desc: "mock tool", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Type: "string", Desc: "name", Required: true, }, }), }, nil } func (m *mockTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { opt := tool.GetImplSpecificOptions(&mockToolOptions{}, opts...) req := &mockToolRequest{} if e := sonic.UnmarshalString(argumentsInJSON, req); e != nil { return "", e } resp := &mockToolResponse{ Echo: fmt.Sprintf("%v: %v", req.Name, opt.Age), } return sonic.MarshalString(resp) } func (m *mockTool) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { sr, sw := schema.Pipe[string](1) go func() { defer sw.Close() opt := tool.GetImplSpecificOptions(&mockToolOptions{}, opts...) 
req := &mockToolRequest{} if e := sonic.UnmarshalString(argumentsInJSON, req); e != nil { sw.Send("", e) return } resp := mockToolResponse{ Echo: fmt.Sprintf("%v: %v", req.Name, opt.Age), } output, err := sonic.MarshalString(resp) if err != nil { sw.Send("", err) return } for i := 0; i < len(output); i++ { sw.Send(string(output[i]), nil) } }() return sr, nil } func TestUnknownTool(t *testing.T) { ctx := context.Background() tn, err := NewToolNode(ctx, &ToolsNodeConfig{ Tools: nil, UnknownToolsHandler: func(ctx context.Context, name, input string) (string, error) { return "unknown", nil }, }) assert.NoError(t, err) input := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { ID: "1", Function: schema.FunctionCall{ Name: "unknown1", Arguments: `arg1`, }, }, { ID: "2", Function: schema.FunctionCall{ Name: "unknown2", Arguments: `arg2`, }, }, }, } expected := []*schema.Message{ { Role: schema.Tool, Content: "unknown", ToolCallID: "1", ToolName: "unknown1", }, { Role: schema.Tool, Content: "unknown", ToolCallID: "2", ToolName: "unknown2", }, } result, err := tn.Invoke(ctx, input) assert.NoError(t, err) assert.Equal(t, expected, result) streamResult, err := tn.Stream(ctx, input) assert.NoError(t, err) result = make([]*schema.Message, 2) for { chunk, err := streamResult.Recv() if err == io.EOF { break } assert.NoError(t, err) for i := range chunk { if chunk[i] != nil { result[i] = chunk[i] } } } assert.Equal(t, expected, result) } func TestToolRerun(t *testing.T) { type myToolRerunState struct { In *schema.Message } schema.Register[myToolRerunState]() tc := []schema.ToolCall{ { ID: "3", Function: schema.FunctionCall{ Name: "tool3", Arguments: "input", }, }, { ID: "4", Function: schema.FunctionCall{ Name: "tool4", Arguments: "input", }, }, { ID: "1", Function: schema.FunctionCall{ Name: "tool1", Arguments: "input", }, }, { ID: "2", Function: schema.FunctionCall{ Name: "tool2", Arguments: "input", }, }, } g := NewGraph[*schema.Message, 
string](WithGenLocalState(func(ctx context.Context) (state *myToolRerunState) { return &myToolRerunState{In: &schema.Message{Role: schema.Assistant, ToolCalls: tc}} })) ctx := context.Background() tn, err := NewToolNode(ctx, &ToolsNodeConfig{ Tools: []tool.BaseTool{&myTool1{}, &myTool2{}, &myTool3{t: t}, &myTool4{t: t}}, }) assert.NoError(t, err) assert.NoError(t, g.AddToolsNode("tool node", tn)) assert.NoError(t, g.AddLambdaNode("lambda", InvokableLambda(func(ctx context.Context, input []*schema.Message) (output string, err error) { contents := make([]string, len(input)) for _, m := range input { callID := m.ToolCallID callIDInt, err := strconv.Atoi(callID) if err != nil { return "", err } contents[callIDInt-1] = m.Content } sb := strings.Builder{} for _, m := range contents { sb.WriteString(m) } return sb.String(), nil }))) assert.NoError(t, g.AddEdge(START, "tool node")) assert.NoError(t, g.AddEdge("tool node", "lambda")) assert.NoError(t, g.AddEdge("lambda", END)) r, err := g.Compile(ctx, WithCheckPointStore(&inMemoryStore{m: map[string][]byte{}})) assert.NoError(t, err) _, err = r.Stream(ctx, &schema.Message{Role: schema.Assistant, ToolCalls: tc}, WithCheckPointID("1")) info, ok := ExtractInterruptInfo(err) assert.True(t, ok) assert.Equal(t, []string{"tool node"}, info.RerunNodes) assert.Equal(t, &ToolsInterruptAndRerunExtra{ ToolCalls: tc, RerunTools: []string{"1", "2"}, RerunExtraMap: map[string]any{"1": "tool1 rerun extra", "2": "tool2 rerun extra"}, ExecutedTools: map[string]string{ "3": "tool3 input: input", "4": "tool4 input: input", }, ExecutedEnhancedTools: make(map[string]*schema.ToolResult), }, info.RerunNodesExtra["tool node"]) sr, err := r.Stream(ctx, nil, WithCheckPointID("1")) assert.NoError(t, err) result, err := concatStreamReader(sr) assert.NoError(t, err) assert.Equal(t, "tool1 input: inputtool2 input: inputtool3 input: inputtool4 input: input", result) } func TestToolMiddleware(t *testing.T) { ctx := context.Background() t3 := &myTool3{t: t} 
t4 := &myTool4{t: t} tn, err := NewToolNode(ctx, &ToolsNodeConfig{ Tools: []tool.BaseTool{t3, t4}, ToolCallMiddlewares: []ToolMiddleware{ { Invokable: func(endpoint InvokableToolEndpoint) InvokableToolEndpoint { return func(ctx context.Context, input *ToolInput) (*ToolOutput, error) { _, err := endpoint(ctx, input) if err != nil { return nil, err } return &ToolOutput{Result: "middleware1"}, nil } }, }, { Streamable: func(endpoint StreamableToolEndpoint) StreamableToolEndpoint { return func(ctx context.Context, input *ToolInput) (*StreamToolOutput, error) { _, err := endpoint(ctx, input) if err != nil { return nil, err } return &StreamToolOutput{Result: schema.StreamReaderFromArray([]string{"middleware2"})}, nil } }, }, }, }) assert.NoError(t, err) messages, err := tn.Invoke(ctx, schema.AssistantMessage("", []schema.ToolCall{ {ID: "1", Function: schema.FunctionCall{Name: "tool3", Arguments: ""}}, {ID: "2", Function: schema.FunctionCall{Name: "tool4", Arguments: ""}}, })) assert.NoError(t, err) assert.Len(t, messages, 2) assert.Equal(t, "middleware1", messages[0].Content) assert.Equal(t, "middleware2", messages[1].Content) t3.times, t4.times = 0, 0 // reset t3 t4 messageStreams, err := tn.Stream(ctx, schema.AssistantMessage("", []schema.ToolCall{ {ID: "1", Function: schema.FunctionCall{Name: "tool3", Arguments: ""}}, {ID: "2", Function: schema.FunctionCall{Name: "tool4", Arguments: ""}}, })) assert.NoError(t, err) var messageArray [][]*schema.Message for { chunk, err := messageStreams.Recv() if err == io.EOF { break } assert.NoError(t, err) messageArray = append(messageArray, chunk) } messages, err = schema.ConcatMessageArray(messageArray) assert.Len(t, messages, 2) assert.Equal(t, "middleware1", messages[0].Content) assert.Equal(t, "middleware2", messages[1].Content) } type myTool1 struct { times uint } func (m *myTool1) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "tool1"}, nil } func (m *myTool1) InvokableRun(ctx 
context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { if m.times == 0 { m.times++ return "", tool.Interrupt(ctx, "tool1 rerun extra") } return "tool1 input: " + argumentsInJSON, nil } type myTool2 struct { times uint } func (m *myTool2) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "tool2"}, nil } func (m *myTool2) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { if m.times == 0 { m.times++ return nil, tool.Interrupt(ctx, "tool2 rerun extra") } return schema.StreamReaderFromArray([]string{"tool2 input: ", argumentsInJSON}), nil } type myTool3 struct { t *testing.T times int } func (m *myTool3) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "tool3"}, nil } func (m *myTool3) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { assert.Equal(m.t, 0, m.times) m.times++ return "tool3 input: " + argumentsInJSON, nil } type myTool4 struct { t *testing.T times int } func (m *myTool4) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "tool4"}, nil } func (m *myTool4) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { assert.Equal(m.t, 0, m.times) m.times++ return schema.StreamReaderFromArray([]string{"tool4 input: ", argumentsInJSON}), nil } func newTool[I, O any](info *schema.ToolInfo, f func(ctx context.Context, in I) (O, error)) tool.InvokableTool { return &invokableTool[I, O]{ info: info, fn: f, } } type invokableTool[I, O any] struct { info *schema.ToolInfo fn func(ctx context.Context, in I) (O, error) } func (f *invokableTool[I, O]) Info(ctx context.Context) (*schema.ToolInfo, error) { return f.info, nil } func (f *invokableTool[I, O]) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { t := 
generic.NewInstance[I]() err := sonic.UnmarshalString(argumentsInJSON, t) if err != nil { return "", err } o, err := f.fn(ctx, t) if err != nil { return "", err } return sonic.MarshalString(o) } func newStreamableTool[I, O any](info *schema.ToolInfo, f func(ctx context.Context, in I) (*schema.StreamReader[O], error)) tool.StreamableTool { return &streamableTool[I, O]{ info: info, fn: f, } } type streamableTool[I, O any] struct { info *schema.ToolInfo fn func(ctx context.Context, in I) (*schema.StreamReader[O], error) } func (f *streamableTool[I, O]) Info(ctx context.Context) (*schema.ToolInfo, error) { return f.info, nil } func (f *streamableTool[I, O]) StreamableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (*schema.StreamReader[string], error) { t := generic.NewInstance[I]() err := sonic.UnmarshalString(argumentsInJSON, t) if err != nil { return nil, err } sr, err := f.fn(ctx, t) if err != nil { return nil, err } return schema.StreamReaderWithConvert(sr, func(o O) (string, error) { return sonic.MarshalString(o) }), nil } type enhancedInvokableTool struct { info *schema.ToolInfo fn func(ctx context.Context, input *schema.ToolArgument) (*schema.ToolResult, error) } func (e *enhancedInvokableTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return e.info, nil } func (e *enhancedInvokableTool) InvokableRun(ctx context.Context, toolArgument *schema.ToolArgument, _ ...tool.Option) (*schema.ToolResult, error) { return e.fn(ctx, toolArgument) } type enhancedStreamableTool struct { info *schema.ToolInfo fn func(ctx context.Context, input *schema.ToolArgument) (*schema.StreamReader[*schema.ToolResult], error) } func (e *enhancedStreamableTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return e.info, nil } func (e *enhancedStreamableTool) StreamableRun(ctx context.Context, toolArgument *schema.ToolArgument, _ ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { return e.fn(ctx, toolArgument) } func 
TestEnhancedToolNode(t *testing.T) { ctx := context.Background() enhancedInvokable := &enhancedInvokableTool{ info: &schema.ToolInfo{ Name: "enhanced_invokable_tool", Desc: "test enhanced invokable tool", }, fn: func(ctx context.Context, input *schema.ToolArgument) (*schema.ToolResult, error) { return &schema.ToolResult{ Parts: []schema.ToolOutputPart{ {Type: schema.ToolPartTypeText, Text: "invokable result: " + input.Text}, }, }, nil }, } enhancedStreamable := &enhancedStreamableTool{ info: &schema.ToolInfo{ Name: "enhanced_streamable_tool", Desc: "test enhanced streamable tool", }, fn: func(ctx context.Context, input *schema.ToolArgument) (*schema.StreamReader[*schema.ToolResult], error) { results := []*schema.ToolResult{ {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "stream part 1: " + input.Text}}}, {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: " stream part 2"}}}, } return schema.StreamReaderFromArray(results), nil }, } toolNode, err := NewToolNode(ctx, &ToolsNodeConfig{ Tools: []tool.BaseTool{enhancedInvokable, enhancedStreamable}, }) assert.NoError(t, err) assert.NotNil(t, toolNode) t.Run("enhanced invokable tool", func(t *testing.T) { input := schema.AssistantMessage("", []schema.ToolCall{ { ID: "call1", Function: schema.FunctionCall{ Name: "enhanced_invokable_tool", Arguments: "test input", }, }, }) output, err := toolNode.Invoke(ctx, input) assert.NoError(t, err) assert.Len(t, output, 1) assert.Equal(t, schema.Tool, output[0].Role) assert.Equal(t, "call1", output[0].ToolCallID) }) t.Run("enhanced streamable tool", func(t *testing.T) { input := schema.AssistantMessage("", []schema.ToolCall{ { ID: "call2", Function: schema.FunctionCall{ Name: "enhanced_streamable_tool", Arguments: "test stream", }, }, }) streamReader, err := toolNode.Stream(ctx, input) assert.NoError(t, err) assert.NotNil(t, streamReader) var messages []*schema.Message for { chunk, err := streamReader.Recv() if err != nil { break } if chunk != 
nil {
				messages = append(messages, chunk...)
			}
		}
		message, err := schema.ConcatMessages(messages)
		assert.NoError(t, err)
		assert.Len(t, messages, 2)
		assert.Equal(t, schema.Tool, messages[0].Role)
		assert.Equal(t, "call2", messages[0].ToolCallID)
		assert.Contains(t, message.UserInputMultiContent[0].Text, "stream part")
	})
}

// TestEnhancedToolConversion checks that a tool implementing only the enhanced
// invokable interface can still be executed through ToolsNode.Stream.
func TestEnhancedToolConversion(t *testing.T) {
	ctx := context.Background()

	enhancedInvokable := &enhancedInvokableTool{
		info: &schema.ToolInfo{
			Name: "enhanced_only_invokable",
			Desc: "test enhanced invokable only",
		},
		fn: func(ctx context.Context, input *schema.ToolArgument) (*schema.ToolResult, error) {
			return &schema.ToolResult{
				Parts: []schema.ToolOutputPart{
					{Type: schema.ToolPartTypeText, Text: "enhanced: " + input.Text},
				},
			}, nil
		},
	}

	toolNode, err := NewToolNode(ctx, &ToolsNodeConfig{
		Tools: []tool.BaseTool{enhancedInvokable},
	})
	assert.NoError(t, err)

	t.Run("enhanced invokable auto-converts to streamable", func(t *testing.T) {
		input := schema.AssistantMessage("", []schema.ToolCall{
			{
				ID: "call1",
				Function: schema.FunctionCall{
					Name:      "enhanced_only_invokable",
					Arguments: "test",
				},
			},
		})

		streamReader, err := toolNode.Stream(ctx, input)
		assert.NoError(t, err)
		assert.NotNil(t, streamReader)

		// Drain the stream. NOTE(review): any Recv error ends the loop, so a
		// non-EOF failure would go unreported here; the reader is never closed.
		var messages []*schema.Message
		for {
			chunk, err := streamReader.Recv()
			if err != nil {
				break
			}
			if chunk != nil {
				messages = append(messages, chunk...)
			}
		}
		// The single tool call is expected to produce exactly one message.
		assert.Len(t, messages, 1)
	})
}

// TestEnhancedToolMiddleware checks which enhanced middleware hooks fire when an
// invokable-only enhanced tool is run via Invoke and via Stream.
func TestEnhancedToolMiddleware(t *testing.T) {
	ctx := context.Background()

	// Flags set by the middlewares below; each subtest resets the flag it checks.
	var invokableMiddlewareCalled bool
	var streamableMiddlewareCalled bool

	enhancedInvokable := &enhancedInvokableTool{
		info: &schema.ToolInfo{
			Name: "enhanced_tool_with_middleware",
			Desc: "test enhanced tool with middleware",
		},
		fn: func(ctx context.Context, input *schema.ToolArgument) (*schema.ToolResult, error) {
			return &schema.ToolResult{
				Parts: []schema.ToolOutputPart{
					{Text: "result", Type: schema.ToolPartTypeText},
				},
			}, nil
		},
	}

	toolNode, err := NewToolNode(ctx, &ToolsNodeConfig{
		Tools: []tool.BaseTool{enhancedInvokable},
		ToolCallMiddlewares: []ToolMiddleware{
			{
				EnhancedInvokable: func(next EnhancedInvokableToolEndpoint) EnhancedInvokableToolEndpoint {
					return func(ctx context.Context, input *ToolInput) (*EnhancedInvokableToolOutput, error) {
						invokableMiddlewareCalled = true
						return next(ctx, input)
					}
				},
				EnhancedStreamable: func(next EnhancedStreamableToolEndpoint) EnhancedStreamableToolEndpoint {
					return func(ctx context.Context, input *ToolInput) (*EnhancedStreamableToolOutput, error) {
						streamableMiddlewareCalled = true
						return next(ctx, input)
					}
				},
			},
		},
	})
	assert.NoError(t, err)

	t.Run("enhanced invokable middleware", func(t *testing.T) {
		invokableMiddlewareCalled = false
		input := schema.AssistantMessage("", []schema.ToolCall{
			{
				ID: "call1",
				Function: schema.FunctionCall{
					Name:      "enhanced_tool_with_middleware",
					Arguments: "test",
				},
			},
		})

		_, err := toolNode.Invoke(ctx, input)
		assert.NoError(t, err)
		assert.True(t, invokableMiddlewareCalled)
	})

	t.Run("enhanced streamable middleware", func(t *testing.T) {
		streamableMiddlewareCalled = false
		input := schema.AssistantMessage("", []schema.ToolCall{
			{
				ID: "call2",
				Function: schema.FunctionCall{
					Name:      "enhanced_tool_with_middleware",
					Arguments: "test",
				},
			},
		})

		streamReader, err := toolNode.Stream(ctx, input)
		assert.NoError(t, err)
		// Drain the stream; any Recv error (including non-EOF) ends the loop.
		for {
			_, err := streamReader.Recv()
			if err != nil {
				break
			}
		}
		// The tool implements only the enhanced invokable interface, and this
		// asserts that streaming it does NOT go through the EnhancedStreamable
		// middleware. NOTE(review): presumably Stream routes invokable-only tools
		// through the invokable endpoint; confirm against ToolsNode.
		assert.False(t, streamableMiddlewareCalled)
	})
}

// TestEnhancedToolPriority checks that invoking an enhanced invokable tool
// surfaces its result text in the returned message's UserInputMultiContent.
func TestEnhancedToolPriority(t *testing.T) {
	ctx := context.Background()

	// NOTE(review): despite the Desc, only an enhanced implementation is
	// registered here; no regular (non-enhanced) tool is involved.
	enhancedInvokable := &enhancedInvokableTool{
		info: &schema.ToolInfo{
			Name: "test_tool",
			Desc: "test tool with both enhanced and regular",
		},
		fn: func(ctx context.Context, input *schema.ToolArgument) (*schema.ToolResult, error) {
			return &schema.ToolResult{
				Parts: []schema.ToolOutputPart{
					{Text: "enhanced result", Type: schema.ToolPartTypeText},
				},
			}, nil
		},
	}

	toolNode, err := NewToolNode(ctx, &ToolsNodeConfig{
		Tools: []tool.BaseTool{enhancedInvokable},
	})
	assert.NoError(t, err)

	t.Run("enhanced tool is used when available", func(t *testing.T) {
		input := schema.AssistantMessage("", []schema.ToolCall{
			{
				ID: "call1",
				Function: schema.FunctionCall{
					Name:      "test_tool",
					Arguments: "test",
				},
			},
		})

		output, err := toolNode.Invoke(ctx, input)
		assert.NoError(t, err)
		assert.Len(t, output, 1)
		assert.Contains(t, output[0].UserInputMultiContent[0].Text, "enhanced result")
	})
}
================================================
FILE: compose/types.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package compose

import (
	"github.com/cloudwego/eino/components"
)

// component is a package-local alias for components.Component.
type component = components.Component

// Built-in component types of graph nodes.
// Each value names the kind of the most primitive executable object provided by the user.
const (
	ComponentOfUnknown     component = "Unknown"
	ComponentOfGraph       component = "Graph"
	ComponentOfWorkflow    component = "Workflow"
	ComponentOfChain       component = "Chain"
	ComponentOfPassthrough component = "Passthrough"
	ComponentOfToolsNode   component = "ToolsNode"
	// ComponentOfLambda is the component recorded in the executor metadata of every Lambda.
	ComponentOfLambda component = "Lambda"
)

// NodeTriggerMode controls the triggering mode of graph nodes.
type NodeTriggerMode string

const (
	// AnyPredecessor means that the node will be triggered when any of its predecessors is included in the previous completed super step.
	// Ref: https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/orchestration_design_principles/#runtime-engine
	AnyPredecessor NodeTriggerMode = "any_predecessor"
	// AllPredecessor means that the current node will only be triggered when all of its predecessor nodes have finished running.
	AllPredecessor NodeTriggerMode = "all_predecessor"
)
================================================
FILE: compose/types_composable.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package compose

import (
	"context"
	"reflect"
)

// AnyGraph is the common interface of the composable and compilable
// Graph[I, O] and Chain[I, O] types in Eino.
type AnyGraph interface {
	getGenericHelper() *genericHelper
	// compile turns the graph into a *composableRunnable using the given compile options.
	compile(ctx context.Context, options *graphCompileOptions) (*composableRunnable, error)
	// inputType and outputType report reflected types; presumably the graph's
	// I and O type parameters — confirm against implementations.
	inputType() reflect.Type
	outputType() reflect.Type
	// component reports which kind of graph this is.
	component() component
}
================================================
FILE: compose/types_lambda.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package compose

import (
	"context"
	"fmt"

	"github.com/cloudwego/eino/schema"
)

// Invoke is the type of the invokable lambda function: one input value in,
// one output value out, with call options.
type Invoke[I, O, TOption any] func(ctx context.Context, input I, opts ...TOption) (output O, err error)

// Stream is the type of the streamable lambda function: one input value in,
// a stream of outputs out, with call options.
type Stream[I, O, TOption any] func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error)

// Collect is the type of the collectable lambda function: a stream of inputs
// in, one output value out, with call options.
type Collect[I, O, TOption any] func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error)

// Transform is the type of the transformable lambda function: a stream of
// inputs in, a stream of outputs out, with call options.
type Transform[I, O, TOption any] func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error)

// InvokeWOOpt is the type of the invokable lambda function without options.
type InvokeWOOpt[I, O any] func(ctx context.Context, input I) (output O, err error)

// StreamWOOpt is the type of the streamable lambda function without options.
type StreamWOOpt[I, O any] func(ctx context.Context, input I) (output *schema.StreamReader[O], err error)

// CollectWOOpt is the collectable lambda signature without call options:
// it consumes a stream of I and produces a single O.
type CollectWOOpt[I, O any] func(ctx context.Context, input *schema.StreamReader[I]) (output O, err error)

// TransformWOOpts is the transformable lambda signature without call options:
// it consumes a stream of I and produces a stream of O.
type TransformWOOpts[I, O any] func(ctx context.Context, input *schema.StreamReader[I]) (output *schema.StreamReader[O], err error)

// Lambda is a graph node that wraps a user-provided lambda function.
// It can be added to a Graph or a Chain (including Parallel and Branch).
// Build one with AnyLambda, InvokableLambda, StreamableLambda,
// CollectableLambda or TransformableLambda, e.g.
//
//	lambda := compose.InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
//		return input, nil
//	})
type Lambda struct {
	executor *composableRunnable
}

// lambdaOpts holds the settings applied by LambdaOpt functions.
type lambdaOpts struct {
	// enableComponentCallback mirrors executorMeta.isComponentCallbackEnabled.
	// When true, the user's lambda is expected to execute the callback aspect
	// itself, and the graph node will not execute those callbacks again.
	enableComponentCallback bool

	// componentImplType mirrors executorMeta.componentImplType.
	// For AnyLambda it comes from the user's explicit config; when it is empty,
	// the class or func name of the instance may be inferred, with no guarantee.
	componentImplType string
}

// LambdaOpt configures a Lambda at construction time.
type LambdaOpt func(o *lambdaOpts)

// WithLambdaCallbackEnable sets whether the lambda executes the callback aspect
// itself (true) or leaves it to the graph node (false, the default).
func WithLambdaCallbackEnable(y bool) LambdaOpt {
	setter := func(cfg *lambdaOpts) {
		cfg.enableComponentCallback = y
	}
	return setter
}

// WithLambdaType sets the component implementation type recorded for the lambda.
func WithLambdaType(t string) LambdaOpt { return func(o *lambdaOpts) { o.componentImplType = t } } type unreachableOption struct{} // InvokableLambdaWithOption creates a Lambda with invokable lambda function and options. func InvokableLambdaWithOption[I, O, TOption any](i Invoke[I, O, TOption], opts ...LambdaOpt) *Lambda { return anyLambda(i, nil, nil, nil, opts...) } // InvokableLambda creates a Lambda with invokable lambda function without options. func InvokableLambda[I, O any](i InvokeWOOpt[I, O], opts ...LambdaOpt) *Lambda { f := func(ctx context.Context, input I, opts_ ...unreachableOption) (output O, err error) { return i(ctx, input) } return anyLambda(f, nil, nil, nil, opts...) } // StreamableLambdaWithOption creates a Lambda with streamable lambda function and options. func StreamableLambdaWithOption[I, O, TOption any](s Stream[I, O, TOption], opts ...LambdaOpt) *Lambda { return anyLambda(nil, s, nil, nil, opts...) } // StreamableLambda creates a Lambda with streamable lambda function without options. func StreamableLambda[I, O any](s StreamWOOpt[I, O], opts ...LambdaOpt) *Lambda { f := func(ctx context.Context, input I, opts_ ...unreachableOption) ( output *schema.StreamReader[O], err error) { return s(ctx, input) } return anyLambda(nil, f, nil, nil, opts...) } // CollectableLambdaWithOption creates a Lambda with collectable lambda function and options. func CollectableLambdaWithOption[I, O, TOption any](c Collect[I, O, TOption], opts ...LambdaOpt) *Lambda { return anyLambda(nil, nil, c, nil, opts...) } // CollectableLambda creates a Lambda with collectable lambda function without options. func CollectableLambda[I, O any](c CollectWOOpt[I, O], opts ...LambdaOpt) *Lambda { f := func(ctx context.Context, input *schema.StreamReader[I], opts_ ...unreachableOption) (output O, err error) { return c(ctx, input) } return anyLambda(nil, nil, f, nil, opts...) } // TransformableLambdaWithOption creates a Lambda with transformable lambda function and options. 
func TransformableLambdaWithOption[I, O, TOption any](t Transform[I, O, TOption], opts ...LambdaOpt) *Lambda { return anyLambda(nil, nil, nil, t, opts...) } // TransformableLambda creates a Lambda with transformable lambda function without options. func TransformableLambda[I, O any](t TransformWOOpts[I, O], opts ...LambdaOpt) *Lambda { f := func(ctx context.Context, input *schema.StreamReader[I], opts_ ...unreachableOption) (output *schema.StreamReader[O], err error) { return t(ctx, input) } return anyLambda(nil, nil, nil, f, opts...) } // AnyLambda creates a Lambda with any lambda function. // you can only implement one or more of the four lambda functions, and the rest use nil. // eg. // // invokeFunc := func(ctx context.Context, input string, opts ...myOption) (output string, err error) { // // ... // } // streamFunc := func(ctx context.Context, input string, opts ...myOption) (output *schema.StreamReader[string], err error) { // // ... // } // // lambda := compose.AnyLambda(invokeFunc, streamFunc, nil, nil) func AnyLambda[I, O, TOption any](i Invoke[I, O, TOption], s Stream[I, O, TOption], c Collect[I, O, TOption], t Transform[I, O, TOption], opts ...LambdaOpt) (*Lambda, error) { if i == nil && s == nil && c == nil && t == nil { return nil, fmt.Errorf("needs to have at least one of four lambda types: invoke/stream/collect/transform, got none") } return anyLambda(i, s, c, t, opts...), nil } func anyLambda[I, O, TOption any](i Invoke[I, O, TOption], s Stream[I, O, TOption], c Collect[I, O, TOption], t Transform[I, O, TOption], opts ...LambdaOpt) *Lambda { opt := getLambdaOpt(opts...) 
executor := runnableLambda(i, s, c, t, !opt.enableComponentCallback, ) executor.meta = &executorMeta{ component: ComponentOfLambda, isComponentCallbackEnabled: opt.enableComponentCallback, componentImplType: opt.componentImplType, } return &Lambda{ executor: executor, } } func getLambdaOpt(opts ...LambdaOpt) *lambdaOpts { opt := &lambdaOpts{ enableComponentCallback: false, componentImplType: "", } for _, optFn := range opts { optFn(opt) } return opt } // ToList creates a Lambda that converts input I to a []I. // It's useful when you want to convert a single input to a list of inputs. // eg. // // lambda := compose.ToList[*schema.Message]() // chain := compose.NewChain[[]*schema.Message, []*schema.Message]() // // chain.AddChatModel(chatModel) // chatModel returns *schema.Message, but we need []*schema.Message // chain.AddLambda(lambda) // convert *schema.Message to []*schema.Message func ToList[I any](opts ...LambdaOpt) *Lambda { i := func(ctx context.Context, input I, opts_ ...unreachableOption) (output []I, err error) { return []I{input}, nil } f := func(ctx context.Context, inputS *schema.StreamReader[I], opts_ ...unreachableOption) (outputS *schema.StreamReader[[]I], err error) { return schema.StreamReaderWithConvert(inputS, func(i I) ([]I, error) { return []I{i}, nil }), nil } return anyLambda(i, nil, nil, f, opts...) } // MessageParser creates a lambda that parses a message into an object T, usually used after a chatmodel. 
// usage: // // parser := schema.NewMessageJSONParser[MyStruct](&schema.MessageJSONParseConfig{ // ParseFrom: schema.MessageParseFromContent, // }) // parserLambda := MessageParser(parser) // // chain := NewChain[*schema.Message, MyStruct]() // chain.AppendChatModel(chatModel) // chain.AppendLambda(parserLambda) // // r, err := chain.Compile(context.Background()) // // // parsed is a MyStruct object // parsed, err := r.Invoke(context.Background(), &schema.Message{ // Role: schema.MessageRoleUser, // Content: "return a json string for my struct", // }) func MessageParser[T any](p schema.MessageParser[T], opts ...LambdaOpt) *Lambda { i := func(ctx context.Context, input *schema.Message, opts_ ...unreachableOption) (output T, err error) { return p.Parse(ctx, input) } opts = append([]LambdaOpt{WithLambdaType("MessageParse")}, opts...) return anyLambda(i, nil, nil, nil, opts...) } ================================================ FILE: compose/types_lambda_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" ) func TestLambda(t *testing.T) { t.Run("InvokableLambda", func(t *testing.T) { ld := InvokableLambdaWithOption( func(ctx context.Context, input string, opts ...any) (output string, err error) { return "good", nil }, WithLambdaCallbackEnable(false), WithLambdaType("ForTest"), ) assert.Equal(t, false, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) ld = InvokableLambda( func(ctx context.Context, input string) (output string, err error) { return "good", nil }, WithLambdaCallbackEnable(false), WithLambdaType("ForTest"), ) assert.Equal(t, false, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) }) t.Run("StreamableLambda", func(t *testing.T) { ld := StreamableLambdaWithOption( func(ctx context.Context, input string, opts ...any) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](1) sw.Close() return sr, nil }, WithLambdaCallbackEnable(false), WithLambdaType("ForTest"), ) assert.Equal(t, false, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) ld = StreamableLambda( func(ctx context.Context, input string) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](1) sw.Close() return sr, nil }, WithLambdaCallbackEnable(false), WithLambdaType("ForTest"), ) assert.Equal(t, false, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) }) t.Run("CollectableLambda", func(t *testing.T) { ld := CollectableLambdaWithOption( func(ctx 
context.Context, input *schema.StreamReader[string], opts ...any) (output string, err error) { return "good", nil }, WithLambdaCallbackEnable(true), ) assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "", ld.executor.meta.componentImplType) ld = CollectableLambda( func(ctx context.Context, input *schema.StreamReader[string]) (output string, err error) { return "good", nil }, WithLambdaCallbackEnable(true), ) assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "", ld.executor.meta.componentImplType) }) t.Run("TransformableLambda", func(t *testing.T) { ld := TransformableLambdaWithOption( func(ctx context.Context, input *schema.StreamReader[string], opts ...any) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](1) sw.Close() return sr, nil }, WithLambdaCallbackEnable(true), ) assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "", ld.executor.meta.componentImplType) ld = TransformableLambda( func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](1) sw.Close() return sr, nil }, WithLambdaCallbackEnable(true), ) assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "", ld.executor.meta.componentImplType) }) t.Run("AnyLambda", func(t *testing.T) { ld, err := AnyLambda[string, string]( func(ctx context.Context, input string, opts ...any) (output string, err error) { return "good", nil }, func(ctx context.Context, input string, opts ...any) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](1) sw.Close() return sr, nil }, func(ctx context.Context, input 
*schema.StreamReader[string], opts ...any) (output string, err error) { return "good", nil }, func(ctx context.Context, input *schema.StreamReader[string], opts ...any) (output *schema.StreamReader[string], err error) { sr, sw := schema.Pipe[string](1) sw.Close() return sr, nil }, WithLambdaCallbackEnable(true), WithLambdaType("ForTest"), ) assert.NoError(t, err) assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) }) } type TestStructForParse struct { ID int `json:"id"` } func TestMessageParser(t *testing.T) { t.Run("parse from content", func(t *testing.T) { parser := schema.NewMessageJSONParser[TestStructForParse](&schema.MessageJSONParseConfig{ ParseFrom: schema.MessageParseFromContent, }) parserLambda := MessageParser(parser) chain := NewChain[*schema.Message, TestStructForParse]() chain.AppendLambda(parserLambda) r, err := chain.Compile(context.Background()) assert.Nil(t, err) parsed, err := r.Invoke(context.Background(), &schema.Message{ Content: `{"id": 1}`, }) assert.Nil(t, err) assert.Equal(t, 1, parsed.ID) }) t.Run("parse from tool call", func(t *testing.T) { parser := schema.NewMessageJSONParser[*TestStructForParse](&schema.MessageJSONParseConfig{ ParseFrom: schema.MessageParseFromToolCall, }) parserLambda := MessageParser(parser) chain := NewChain[*schema.Message, *TestStructForParse]() chain.AppendLambda(parserLambda) r, err := chain.Compile(context.Background()) assert.Nil(t, err) parsed, err := r.Invoke(context.Background(), &schema.Message{ ToolCalls: []schema.ToolCall{ {Function: schema.FunctionCall{Arguments: `{"id": 1}`}}, }, }) assert.Nil(t, err) assert.Equal(t, 1, parsed.ID) }) } ================================================ FILE: compose/utils.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * 
you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "fmt" "reflect" "github.com/cloudwego/eino/callbacks" icb "github.com/cloudwego/eino/internal/callbacks" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) type on[T any] func(context.Context, T) (context.Context, T) func onStart[T any](ctx context.Context, input T) (context.Context, T) { return icb.On(ctx, input, icb.OnStartHandle[T], callbacks.TimingOnStart, true) } func onEnd[T any](ctx context.Context, output T) (context.Context, T) { return icb.On(ctx, output, icb.OnEndHandle[T], callbacks.TimingOnEnd, false) } func onStartWithStreamInput[T any](ctx context.Context, input *schema.StreamReader[T]) ( context.Context, *schema.StreamReader[T]) { return icb.On(ctx, input, icb.OnStartWithStreamInputHandle[T], callbacks.TimingOnStartWithStreamInput, true) } func genericOnStartWithStreamInputHandle(ctx context.Context, input streamReader, runInfo *icb.RunInfo, handlers []icb.Handler) (context.Context, streamReader) { handlers = generic.Reverse(handlers) cpy := input.copy handle := func(ctx context.Context, handler icb.Handler, in streamReader) context.Context { in_, ok := unpackStreamReader[icb.CallbackInput](in) if !ok { panic("impossible") } return handler.OnStartWithStreamInput(ctx, runInfo, in_) } return icb.OnWithStreamHandle(ctx, input, handlers, cpy, handle) } func genericOnStartWithStreamInput(ctx context.Context, input streamReader) (context.Context, streamReader) { return icb.On(ctx, input, genericOnStartWithStreamInputHandle, 
callbacks.TimingOnStartWithStreamInput, true) } func onEndWithStreamOutput[T any](ctx context.Context, output *schema.StreamReader[T]) ( context.Context, *schema.StreamReader[T]) { return icb.On(ctx, output, icb.OnEndWithStreamOutputHandle[T], callbacks.TimingOnEndWithStreamOutput, false) } func genericOnEndWithStreamOutputHandle(ctx context.Context, output streamReader, runInfo *icb.RunInfo, handlers []icb.Handler) (context.Context, streamReader) { cpy := output.copy handle := func(ctx context.Context, handler icb.Handler, out streamReader) context.Context { out_, ok := unpackStreamReader[icb.CallbackOutput](out) if !ok { panic("impossible") } return handler.OnEndWithStreamOutput(ctx, runInfo, out_) } return icb.OnWithStreamHandle(ctx, output, handlers, cpy, handle) } func genericOnEndWithStreamOutput(ctx context.Context, output streamReader) (context.Context, streamReader) { return icb.On(ctx, output, genericOnEndWithStreamOutputHandle, callbacks.TimingOnEndWithStreamOutput, false) } func onError(ctx context.Context, err error) (context.Context, error) { return icb.On(ctx, err, icb.OnErrorHandle, callbacks.TimingOnError, false) } func runWithCallbacks[I, O, TOption any](r func(context.Context, I, ...TOption) (O, error), onStart on[I], onEnd on[O], onError on[error]) func(context.Context, I, ...TOption) (O, error) { return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { ctx, input = onStart(ctx, input) output, err = r(ctx, input, opts...) 
if err != nil { ctx, err = onError(ctx, err) return output, err } ctx, output = onEnd(ctx, output) return output, nil } } func invokeWithCallbacks[I, O, TOption any](i Invoke[I, O, TOption]) Invoke[I, O, TOption] { return runWithCallbacks(i, onStart[I], onEnd[O], onError) } func onGraphStart(ctx context.Context, input any, isStream bool) (context.Context, any) { if isStream { return genericOnStartWithStreamInput(ctx, input.(streamReader)) } return onStart(ctx, input) } func onGraphEnd(ctx context.Context, output any, isStream bool) (context.Context, any) { if isStream { return genericOnEndWithStreamOutput(ctx, output.(streamReader)) } return onEnd(ctx, output) } func onGraphError(ctx context.Context, err error) (context.Context, error) { return onError(ctx, err) } func streamWithCallbacks[I, O, TOption any](s Stream[I, O, TOption]) Stream[I, O, TOption] { return runWithCallbacks(s, onStart[I], onEndWithStreamOutput[O], onError) } func collectWithCallbacks[I, O, TOption any](c Collect[I, O, TOption]) Collect[I, O, TOption] { return runWithCallbacks(c, onStartWithStreamInput[I], onEnd[O], onError) } func transformWithCallbacks[I, O, TOption any](t Transform[I, O, TOption]) Transform[I, O, TOption] { return runWithCallbacks(t, onStartWithStreamInput[I], onEndWithStreamOutput[O], onError) } func initGraphCallbacks(ctx context.Context, info *nodeInfo, meta *executorMeta, opts ...Option) context.Context { ri := &callbacks.RunInfo{} if meta != nil { ri.Component = meta.component ri.Type = meta.componentImplType } if info != nil { ri.Name = info.name } var cbs []callbacks.Handler for i := range opts { if len(opts[i].handler) != 0 && len(opts[i].paths) == 0 { cbs = append(cbs, opts[i].handler...) } } if len(cbs) == 0 { return icb.ReuseHandlers(ctx, ri) } return icb.AppendHandlers(ctx, ri, cbs...) 
} func initNodeCallbacks(ctx context.Context, key string, info *nodeInfo, meta *executorMeta, opts ...Option) context.Context { ri := &callbacks.RunInfo{} if meta != nil { ri.Component = meta.component ri.Type = meta.componentImplType } if info != nil { ri.Name = info.name } var cbs []callbacks.Handler for i := range opts { if len(opts[i].handler) != 0 { if len(opts[i].paths) != 0 { for _, k := range opts[i].paths { if len(k.path) == 1 && k.path[0] == key { cbs = append(cbs, opts[i].handler...) break } } } } } if len(cbs) == 0 { return icb.ReuseHandlers(ctx, ri) } return icb.AppendHandlers(ctx, ri, cbs...) } func streamChunkConvertForCBOutput[O any](o O) (callbacks.CallbackOutput, error) { return o, nil } func streamChunkConvertForCBInput[I any](i I) (callbacks.CallbackInput, error) { return i, nil } func toAnyList[T any](in []T) []any { ret := make([]any, len(in)) for i := range in { ret[i] = in[i] } return ret } type assignableType uint8 const ( assignableTypeMustNot assignableType = iota assignableTypeMust assignableTypeMay ) func checkAssignable(input, arg reflect.Type) assignableType { if arg == nil || input == nil { return assignableTypeMustNot } if arg == input { return assignableTypeMust } if arg.Kind() == reflect.Interface && input.Implements(arg) { return assignableTypeMust } if input.Kind() == reflect.Interface { if arg.Implements(input) { return assignableTypeMay } return assignableTypeMustNot } return assignableTypeMustNot } func extractOption(nodes map[string]*chanCall, opts ...Option) (map[string][]any, error) { optMap := map[string][]any{} for _, opt := range opts { if len(opt.paths) == 0 { // common, discard callback, filter option by type if len(opt.options) == 0 { continue } for name, c := range nodes { if c.action.optionType == nil { // subgraph optMap[name] = append(optMap[name], opt) } else if reflect.TypeOf(opt.options[0]) == c.action.optionType { // assume that types of options are the same optMap[name] = append(optMap[name], opt.options...) 
} } } for _, path := range opt.paths { if len(path.path) == 0 { return nil, fmt.Errorf("call option has designated an empty path") } var curNode *chanCall var ok bool if curNode, ok = nodes[path.path[0]]; !ok { return nil, fmt.Errorf("option has designated an unknown node: %s", path) } curNodeKey := path.path[0] if len(path.path) == 1 { if len(opt.options) == 0 { // sub graph common callbacks has been added to ctx in initNodeCallback and won't be passed to subgraph only pass options // node callback also won't be passed continue } if curNode.action.optionType == nil { nOpt := opt.deepCopy() nOpt.paths = []*NodePath{} optMap[curNodeKey] = append(optMap[curNodeKey], nOpt) } else { // designate to component if curNode.action.optionType != reflect.TypeOf(opt.options[0]) { // assume that types of options are the same return nil, fmt.Errorf("option type[%s] is different from which the designated node[%s] expects[%s]", reflect.TypeOf(opt.options[0]).String(), path, curNode.action.optionType.String()) } optMap[curNodeKey] = append(optMap[curNodeKey], opt.options...) } } else { if curNode.action.optionType != nil { // component return nil, fmt.Errorf("cannot designate sub path of a component, path:%s", path) } // designate to sub graph's nodes nOpt := opt.deepCopy() nOpt.paths = []*NodePath{NewNodePath(path.path[1:]...)} optMap[curNodeKey] = append(optMap[curNodeKey], nOpt) } } } return optMap, nil } func mapToList(m map[string]any) []any { ret := make([]any, 0, len(m)) for _, v := range m { ret = append(ret, v) } return ret } ================================================ FILE: compose/utils_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/internal/generic" ) type good interface { ThisIsGood() bool } type good2 interface { ThisIsGood2() bool } type good3 interface { ThisIsGood() bool } type goodImpl struct{} func (g *goodImpl) ThisIsGood() bool { return true } type goodNotImpl struct{} func TestValidateType(t *testing.T) { t.Run("equal_type", func(t *testing.T) { arg := generic.TypeOf[int]() input := generic.TypeOf[int]() result := checkAssignable(input, arg) assert.Equal(t, assignableTypeMust, result) }) t.Run("unequal_type", func(t *testing.T) { arg := generic.TypeOf[int]() input := generic.TypeOf[string]() result := checkAssignable(input, arg) assert.Equal(t, assignableTypeMustNot, result) }) t.Run("implement_interface", func(t *testing.T) { arg := generic.TypeOf[good]() input := generic.TypeOf[*goodImpl]() result := checkAssignable(input, arg) assert.Equal(t, assignableTypeMust, result) }) t.Run("may_implement_interface", func(t *testing.T) { arg := generic.TypeOf[*goodImpl]() input := generic.TypeOf[good]() result := checkAssignable(input, arg) assert.Equal(t, assignableTypeMay, result) }) t.Run("not_implement_interface", func(t *testing.T) { arg := generic.TypeOf[good]() input := generic.TypeOf[*goodNotImpl]() result := checkAssignable(input, arg) assert.Equal(t, assignableTypeMustNot, result) }) t.Run("interface_unequal_interface", func(t *testing.T) { arg := generic.TypeOf[good]() input := generic.TypeOf[good2]() result := checkAssignable(input, arg) assert.Equal(t, assignableTypeMustNot, 
result) }) t.Run("interface_equal_interface", func(t *testing.T) { arg := generic.TypeOf[good]() input := generic.TypeOf[good3]() result := checkAssignable(input, arg) assert.Equal(t, assignableTypeMust, result) }) } func TestStreamChunkConvert(t *testing.T) { o, err := streamChunkConvertForCBOutput(1) assert.Nil(t, err) assert.Equal(t, o, 1) i, err := streamChunkConvertForCBInput(1) assert.Nil(t, err) assert.Equal(t, i, 1) } ================================================ FILE: compose/values_merge.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "fmt" "reflect" "github.com/cloudwego/eino/internal" ) // RegisterValuesMergeFunc registers a function to merge outputs from multiple nodes when fan-in. // It's used to define how to merge for a specific type. // For maps that already have a default merge function, you don't need to register a new one unless you want to customize the merge logic. 
func RegisterValuesMergeFunc[T any](fn func([]T) (T, error)) {
	internal.RegisterValuesMergeFunc(fn)
}

// mergeOptions tunes how mergeValues combines streamReader inputs.
type mergeOptions struct {
	// streamMergeWithSourceEOF makes mergeValues combine streams with
	// mergeWithNames (passing names) instead of merge.
	streamMergeWithSourceEOF bool
	// names is forwarded to mergeWithNames; presumably one name per input
	// stream, in the same order as the values passed to mergeValues.
	names []string
}

// mergeValues merges the values fanned in from multiple predecessor nodes.
//
// It tries, in order:
//  1. a merge function registered for the dynamic type of vs[0]
//     (see RegisterValuesMergeFunc), applied to the whole slice;
//  2. if vs[0] is a streamReader: every element must be a streamReader with
//     the same chunk type, and that chunk type must have a registered merge
//     function. The streams are combined into a single streamReader, using
//     mergeWithNames(opts.names) when opts.streamMergeWithSourceEOF is set
//     and merge otherwise.
//
// Any other type yields an "unsupported type" error. Callers normally pass at
// least two values; an empty slice or a nil first element is reported as an
// error instead of causing a panic.
func mergeValues(vs []any, opts *mergeOptions) (any, error) {
	if len(vs) == 0 {
		return nil, fmt.Errorf("(mergeValues) no values to merge")
	}

	// reflect.TypeOf yields nil for a nil interface value, whereas
	// reflect.ValueOf(nil).Type() would panic.
	t0 := reflect.TypeOf(vs[0])
	if t0 == nil {
		return nil, fmt.Errorf("(mergeValues) unsupported type: %v", t0)
	}

	if fn := internal.GetMergeFunc(t0); fn != nil {
		return fn(vs)
	}

	s, ok := vs[0].(streamReader)
	if !ok {
		return nil, fmt.Errorf("(mergeValues) unsupported type: %v", t0)
	}

	// merge StreamReaders: all inputs must carry the same, mergeable chunk type.
	t := s.getChunkType()
	if internal.GetMergeFunc(t) == nil {
		return nil, fmt.Errorf("(mergeValues | stream type)"+
			" unsupported chunk type: %v", t)
	}

	ss := make([]streamReader, len(vs)-1)
	for i := range ss {
		sri, isStream := vs[i+1].(streamReader)
		if !isStream {
			// ss[i] corresponds to vs[i+1]; report that element's type.
			return nil, fmt.Errorf("(mergeStream) unexpected type. "+
				"expect: %v, got: %v", t0, reflect.TypeOf(vs[i+1]))
		}

		if st := sri.getChunkType(); st != t {
			return nil, fmt.Errorf("(mergeStream) chunk type mismatch. "+
				"expect: %v, got: %v", t, st)
		}

		ss[i] = sri
	}

	if opts != nil && opts.streamMergeWithSourceEOF {
		ms := s.mergeWithNames(ss, opts.names)
		return ms, nil
	}

	ms := s.merge(ss)
	return ms, nil
}

================================================
FILE: compose/values_merge_test.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package compose import ( "fmt" "io" "sort" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/cloudwego/eino/schema" ) func Test_mergeValues(t *testing.T) { t.Run("merge maps", func(t *testing.T) { m1 := map[int]int{1: 1, 2: 2, 3: 3, 4: 4} m2 := map[int]int{5: 5, 6: 6, 7: 7, 8: 8} m3 := map[int]int{9: 9, 10: 10, 11: 11} t.Run("regular", func(t *testing.T) { mergedM, err := mergeValues([]any{m1, m2, m3}, nil) assert.NoError(t, err) m := mergedM.(map[int]int) // len(m) == len(m1) + len(m2) + len(m3) assert.Equal(t, len(m), len(m1)+len(m2)+len(m3)) }) t.Run("duplicated key", func(t *testing.T) { _, err := mergeValues([]any{m1, m2, m3, map[int]int{1: 1}}, nil) assert.ErrorContains(t, err, "duplicated key") }) t.Run("type mismatch", func(t *testing.T) { _, err := mergeValues([]any{m1, m2, m3, map[int]string{1: "1"}}, nil) assert.ErrorContains(t, err, "type mismatch") }) }) t.Run("merge stream", func(t *testing.T) { ass := []any{ packStreamReader(schema.StreamReaderFromArray[map[int]string]([]map[int]string{{1: "1"}})), packStreamReader(schema.StreamReaderFromArray[map[int]string]([]map[int]string{{2: "2"}})), packStreamReader(schema.StreamReaderFromArray[map[int]string]([]map[int]string{{3: "3", 4: "4"}})), } isr, err := mergeValues(ass, nil) require.NoError(t, err) ret, ok := unpackStreamReader[map[int]string](isr.(streamReader)) require.True(t, ok) defer ret.Close() got := make(map[int]string) for i := 0; i < 3; i++ { m, err := ret.Recv() require.NoError(t, err) for k, v := range m { got[k] = v } } _, err = ret.Recv() require.ErrorIs(t, err, io.EOF) assert.Equal(t, map[int]string{ 1: "1", 2: "2", 3: "3", 4: "4", }, got) }) t.Run("merge stream with source EOF", func(t *testing.T) { ass := []any{ packStreamReader(schema.StreamReaderFromArray[map[int]string]([]map[int]string{{1: "1"}})), packStreamReader(schema.StreamReaderFromArray[map[int]string]([]map[int]string{{2: "2"}})), 
packStreamReader(schema.StreamReaderFromArray[map[int]string]([]map[int]string{{3: "3", 4: "4"}})), } opts := &mergeOptions{ streamMergeWithSourceEOF: true, names: []string{ "source0", "source1", "source2", }, } isr, err := mergeValues(ass, opts) require.NoError(t, err) ret, ok := unpackStreamReader[map[int]string](isr.(streamReader)) require.True(t, ok) defer ret.Close() got := make(map[int]string) endedSources := make(map[string]bool) for { m, e := ret.Recv() if e != nil { if sourceName, ok_ := schema.GetSourceName(e); ok_ { t.Logf("Source '%s' ended", sourceName) endedSources[sourceName] = true continue } if e == io.EOF { // This EOF means all chunks from all sources that were not SourceEOF have been merged and sent. // Or, if all sources send SourceEOF first, this io.EOF means the merged stream itself is now empty. break } require.NoError(t, e) // Fail on any other error } // If streamMergeWithSourceEOF is true, the final merged result comes as a single map chunk // after all SourceEOFs (if any non-empty streams existed) or directly if all streams were empty. for k, v := range m { got[k] = v } } // Check that all expected sources have ended if they were part of opts.names for i := 0; i < len(ass); i++ { expectedSourceName := opts.names[i] assert.True(t, endedSources[expectedSourceName], "Expected source %s to have sent SourceEOF", expectedSourceName) } // The final 'got' map should contain all items because streamMergeWithSourceEOF merges them at the end. assert.Equal(t, map[int]string{ 1: "1", 2: "2", 3: "3", 4: "4", }, got) }) type TestType struct { A int B []string } RegisterValuesMergeFunc(func(vs []*TestType) (*TestType, error) { ret := &TestType{} for _, v := range vs { if v == nil { continue } if ret.A < 0 { return nil, fmt.Errorf("test error: %v", ret.A) } ret.A += v.A ret.B = append(ret.B, v.B...) 
} sort.Strings(ret.B) return ret, nil }) t.Run("custom merge", func(t *testing.T) { t.Run("regular", func(t *testing.T) { vs := []any{ &TestType{A: 0, B: []string{}}, &TestType{A: 1, B: []string{"1"}}, &TestType{A: 2, B: []string{"2", "22"}}, &TestType{A: 3, B: []string{"3", "33", "333"}}, } ret, err := mergeValues(vs, nil) require.NoError(t, err) assert.Equal(t, &TestType{ A: 6, B: []string{"1", "2", "22", "3", "33", "333"}, }, ret) }) t.Run("custom error", func(t *testing.T) { vs := []any{ &TestType{A: 0, B: []string{}}, &TestType{A: 1, B: []string{"1"}}, &TestType{A: -2, B: []string{"2", "22"}}, &TestType{A: 3, B: []string{"3", "33", "333"}}, } _, err := mergeValues(vs, nil) require.ErrorContains(t, err, "test error") }) t.Run("type mismatch", func(t *testing.T) { vs := []any{ &TestType{A: 0, B: []string{}}, &TestType{A: 1, B: []string{"1"}}, &TestType{A: 2, B: []string{"2", "22"}}, "test3", } _, err := mergeValues(vs, nil) require.ErrorContains(t, err, "type mismatch") }) t.Run("stream", func(t *testing.T) { ass := []any{ packStreamReader(schema.StreamReaderFromArray([]*TestType{ {A: 0, B: []string{}}, })), packStreamReader(schema.StreamReaderFromArray([]*TestType{ {A: 1, B: []string{"1"}}, })), packStreamReader(schema.StreamReaderFromArray([]*TestType{ {A: 2, B: []string{"2", "22"}}, })), packStreamReader(schema.StreamReaderFromArray([]*TestType{ {A: 3, B: []string{"3", "33", "333"}}, })), } isr, err := mergeValues(ass, nil) require.NoError(t, err) ret, ok := unpackStreamReader[*TestType](isr.(streamReader)) require.True(t, ok) defer ret.Close() var vs []any for i := 0; i < 4; i++ { v, err := ret.Recv() require.NoError(t, err) vs = append(vs, v) } _, err = ret.Recv() require.ErrorIs(t, err, io.EOF) merged, err := mergeValues(vs, nil) require.NoError(t, err) assert.Equal(t, &TestType{ A: 6, B: []string{"1", "2", "22", "3", "33", "333"}, }, merged) }) }) t.Run("unregistered type", func(t *testing.T) { type Unregistered TestType _, err := 
mergeValues([]any{&Unregistered{}}, nil) assert.ErrorContains(t, err, "unsupported type") }) } ================================================ FILE: compose/workflow.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package compose import ( "context" "fmt" "reflect" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" ) // WorkflowNode is the node of the Workflow. type WorkflowNode struct { g *graph key string addInputs []func() error staticValues map[string]any dependencySetter func(fromNodeKey string, typ dependencyType) mappedFieldPath map[string]any } // Workflow is wrapper of graph, replacing AddEdge with declaring dependencies and field mappings between nodes. // Under the hood it uses NodeTriggerMode(AllPredecessor), so does not support cycles. type Workflow[I, O any] struct { g *graph workflowNodes map[string]*WorkflowNode workflowBranches []*WorkflowBranch dependencies map[string]map[string]dependencyType } type dependencyType int const ( normalDependency dependencyType = iota noDirectDependency branchDependency ) // NewWorkflow creates a new Workflow. 
func NewWorkflow[I, O any](opts ...NewGraphOption) *Workflow[I, O] {
	// Apply the options to a local struct only to read the state settings;
	// the raw option list is still forwarded to the graph constructor.
	resolved := &newGraphOptions{}
	for _, apply := range opts {
		apply(resolved)
	}

	g := newGraphFromGeneric[I, O](ComponentOfWorkflow, resolved.withState, resolved.stateType, opts)

	return &Workflow[I, O]{
		g:             g,
		workflowNodes: map[string]*WorkflowNode{},
		dependencies:  map[string]map[string]dependencyType{},
	}
}

// Compile builds the workflow into a runnable graph.
func (wf *Workflow[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) {
	runnable, err := compileAnyGraph[I, O](ctx, wf, opts...)
	return runnable, err
}

// AddChatModelNode adds a chat model node and returns it.
func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.BaseChatModel, opts ...GraphAddNodeOpt) *WorkflowNode {
	// The error is dropped here; Workflow.compile checks wf.g.buildError first
	// (presumably where the graph records it — confirm in graph.go).
	_ = wf.g.AddChatModelNode(key, chatModel, opts...)
	node := wf.initNode(key)
	return node
}

// AddChatTemplateNode adds a chat template node and returns it.
func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode {
	// Error deliberately dropped; see AddChatModelNode.
	_ = wf.g.AddChatTemplateNode(key, chatTemplate, opts...)
	node := wf.initNode(key)
	return node
}

// AddToolsNode adds a tools node and returns it.
func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode {
	// Error deliberately dropped; see AddChatModelNode.
	_ = wf.g.AddToolsNode(key, tools, opts...)
	node := wf.initNode(key)
	return node
}

// AddRetrieverNode adds a retriever node and returns it.
func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode {
	// Error deliberately dropped; see AddChatModelNode.
	_ = wf.g.AddRetrieverNode(key, retriever, opts...)
	node := wf.initNode(key)
	return node
}

// AddEmbeddingNode adds an embedding node and returns it.
func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embedder, opts ...GraphAddNodeOpt) *WorkflowNode {
	// Error deliberately dropped; see AddChatModelNode.
	_ = wf.g.AddEmbeddingNode(key, embedding, opts...)
	node := wf.initNode(key)
	return node
}

// AddIndexerNode adds an indexer node to the workflow and returns it.
func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddIndexerNode(key, indexer, opts...) return wf.initNode(key) } // AddLoaderNode adds a document loader node to the workflow and returns it. func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddLoaderNode(key, loader, opts...) return wf.initNode(key) } // AddDocumentTransformerNode adds a document transformer node and returns it. func (wf *Workflow[I, O]) AddDocumentTransformerNode(key string, transformer document.Transformer, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddDocumentTransformerNode(key, transformer, opts...) return wf.initNode(key) } // AddGraphNode adds a nested graph node to the workflow and returns it. func (wf *Workflow[I, O]) AddGraphNode(key string, graph AnyGraph, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddGraphNode(key, graph, opts...) return wf.initNode(key) } // AddLambdaNode adds a lambda node to the workflow and returns it. func (wf *Workflow[I, O]) AddLambdaNode(key string, lambda *Lambda, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddLambdaNode(key, lambda, opts...) return wf.initNode(key) } // End returns the WorkflowNode representing END node. func (wf *Workflow[I, O]) End() *WorkflowNode { if node, ok := wf.workflowNodes[END]; ok { return node } return wf.initNode(END) } // AddPassthroughNode adds a passthrough node to the workflow and returns it. func (wf *Workflow[I, O]) AddPassthroughNode(key string, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddPassthroughNode(key, opts...) return wf.initNode(key) } // AddInput creates both data and execution dependencies between nodes. // It configures how data flows from the predecessor node (fromNodeKey) to the current node, // and ensures the current node only executes after the predecessor completes. 
// // Parameters: // - fromNodeKey: the key of the predecessor node // - inputs: field mappings that specify how data should flow from the predecessor // to the current node. If no mappings are provided, the entire output of the // predecessor will be used as input. // // Example: // // // Map between specific field // node.AddInput("userNode", MapFields("user.name", "displayName")) // // // Use entire output // node.AddInput("dataNode") // // Returns the current node for method chaining. func (n *WorkflowNode) AddInput(fromNodeKey string, inputs ...*FieldMapping) *WorkflowNode { return n.addDependencyRelation(fromNodeKey, inputs, &workflowAddInputOpts{}) } type workflowAddInputOpts struct { // noDirectDependency indicates whether to create a data mapping without establishing // a direct execution dependency. When true, the current node can access data from // the predecessor node but its execution is not directly blocked by it. noDirectDependency bool // dependencyWithoutInput indicates whether to create an execution dependency // without any data mapping. When true, the current node will wait for the // predecessor node to complete but won't receive any data from it. dependencyWithoutInput bool } // WorkflowAddInputOpt configures behavior of AddInputWithOptions. type WorkflowAddInputOpt func(*workflowAddInputOpts) func getAddInputOpts(opts []WorkflowAddInputOpt) *workflowAddInputOpts { opt := &workflowAddInputOpts{} for _, o := range opts { o(opt) } return opt } // WithNoDirectDependency creates a data mapping without establishing a direct execution dependency. // The predecessor node will still complete before the current node executes, but through indirect // execution paths rather than a direct dependency. // // In a workflow graph, node dependencies typically serve two purposes: // 1. Execution order: determining when nodes should run // 2. 
Data flow: specifying how data passes between nodes // // This option separates these concerns by: // - Creating data mapping from the predecessor to the current node // - Relying on the predecessor's path to reach the current node through other nodes // that have direct execution dependencies // // Example: // // node.AddInputWithOptions("dataNode", mappings, WithNoDirectDependency()) // // Important: // // 1. Branch scenarios: When connecting nodes on different sides of a branch, // WithNoDirectDependency MUST be used to let the branch itself handle the // execution order, preventing incorrect dependencies that could bypass the branch. // // 2. Execution guarantee: The predecessor will still complete before the current // node executes because the predecessor must have a path (through other nodes) // that eventually reaches the current node. // // 3. Graph validity: There MUST be a path from the predecessor that eventually // reaches the current node through other nodes with direct dependencies. // This ensures the execution order while avoiding redundant direct dependencies. // // Common use cases: // - Cross-branch data access where the branch handles execution order // - Avoiding redundant dependencies when a path already exists func WithNoDirectDependency() WorkflowAddInputOpt { return func(opt *workflowAddInputOpts) { opt.noDirectDependency = true } } // AddInputWithOptions creates a dependency between nodes with custom configuration options. // It allows fine-grained control over both data flow and execution dependencies. // // Parameters: // - fromNodeKey: the key of the predecessor node // - inputs: field mappings that specify how data flows from the predecessor to the current node. // If no mappings are provided, the entire output of the predecessor will be used as input. 
// - opts: configuration options that control how the dependency is established // // Example: // // // Create data mapping without direct execution dependency // node.AddInputWithOptions("dataNode", mappings, WithNoDirectDependency()) // // Returns the current node for method chaining. func (n *WorkflowNode) AddInputWithOptions(fromNodeKey string, inputs []*FieldMapping, opts ...WorkflowAddInputOpt) *WorkflowNode { return n.addDependencyRelation(fromNodeKey, inputs, getAddInputOpts(opts)) } // AddDependency creates an execution-only dependency between nodes. // The current node will wait for the predecessor node to complete before executing, // but no data will be passed between them. // // Parameters: // - fromNodeKey: the key of the predecessor node that must complete before this node starts // // Example: // // // Wait for "setupNode" to complete before executing // node.AddDependency("setupNode") // // This is useful when: // - You need to ensure execution order without data transfer // - The predecessor performs setup or initialization that must complete first // - You want to explicitly separate execution dependencies from data flow // // Returns the current node for method chaining. func (n *WorkflowNode) AddDependency(fromNodeKey string) *WorkflowNode { return n.addDependencyRelation(fromNodeKey, nil, &workflowAddInputOpts{dependencyWithoutInput: true}) } // SetStaticValue sets a static value for a field path that will be available // during workflow execution. These values are determined at compile time and // remain constant throughout the workflow's lifecycle. 
//
// Example:
//
//	node.SetStaticValue(FieldPath{"query"}, "static query")
//
// Setting the same path again overwrites the earlier value. When the workflow
// is compiled, the static paths are checked against the node's other input
// mappings, and a conflicting path makes compilation fail.
func (n *WorkflowNode) SetStaticValue(path FieldPath, value any) *WorkflowNode {
	n.staticValues[path.join()] = value
	return n
}

// addDependencyRelation records fromNodeKey as a predecessor of n. The graph
// edge is not added immediately. Instead, a closure is queued on n.addInputs
// and executed when the workflow is compiled.
//
// options selects one of three modes (noDirectDependency takes precedence):
//   - noDirectDependency: data mapping from inputs, no direct execution edge;
//   - dependencyWithoutInput: execution edge only; inputs must be empty;
//   - neither: data mapping and direct execution edge.
//
// Every input's fromNodeKey is set right away, mutating the caller's
// FieldMapping values.
func (n *WorkflowNode) addDependencyRelation(fromNodeKey string, inputs []*FieldMapping, options *workflowAddInputOpts) *WorkflowNode {
	for _, input := range inputs {
		input.fromNodeKey = fromNodeKey
	}

	noDirect := options.noDirectDependency
	noInput := !noDirect && options.dependencyWithoutInput

	depType := normalDependency
	if noDirect {
		depType = noDirectDependency
	}

	n.addInputs = append(n.addInputs, func() error {
		var err error
		if noInput {
			if len(inputs) > 0 {
				return fmt.Errorf("dependency without input should not have inputs. node: %s, fromNode: %s, inputs: %v", n.key, fromNodeKey, inputs)
			}
			// Execution-only edge: no field mappings are forwarded.
			err = n.g.addEdgeWithMappings(fromNodeKey, n.key, false, true)
		} else {
			var paths []FieldPath
			for _, input := range inputs {
				paths = append(paths, input.targetPath())
			}
			if err = n.checkAndAddMappedPath(paths); err != nil {
				return err
			}
			err = n.g.addEdgeWithMappings(fromNodeKey, n.key, noDirect, false, inputs...)
		}
		if err != nil {
			return err
		}

		n.dependencySetter(fromNodeKey, depType)
		return nil
	})

	return n
}

// checkAndAddMappedPath records the target field paths mapped into n and
// rejects overlapping mappings.
//
// n.mappedFieldPath[""] is the root of the bookkeeping structure:
//   - struct{}{}: the entire input of n is mapped (registered with empty paths);
//   - map[string]any: a trie of mapped field paths whose leaves are struct{}{}.
//
// It returns an error when:
//   - the entire input is mapped together with any other mapping, in either order;
//   - the same target path is mapped twice;
//   - one target path is a prefix of another (e.g. "a" and "a.b"), in either order.
func (n *WorkflowNode) checkAndAddMappedPath(paths []FieldPath) error {
	root, exists := n.mappedFieldPath[""]
	if !exists {
		if len(paths) == 0 {
			n.mappedFieldPath[""] = struct{}{}
			return nil
		}
		root = map[string]any{}
		n.mappedFieldPath[""] = root
	}

	if _, whole := root.(struct{}); whole {
		return fmt.Errorf("entire output has already been mapped for node: %s", n.key)
	}

	if len(paths) == 0 {
		// Mapping the entire output on top of existing field mappings overlaps
		// them, mirroring the "entire output already mapped" check above.
		return fmt.Errorf("entire output cannot be mapped for node %s: field paths have already been mapped", n.key)
	}

	tree := root.(map[string]any)
	for _, targetPath := range paths {
		cur := tree
		var traversed FieldPath
		for i, seg := range targetPath {
			traversed = append(traversed, seg)
			existing, found := cur[seg]

			// A shorter (or identical) path already ends here.
			if _, terminal := existing.(struct{}); found && terminal {
				return fmt.Errorf("two terminal field paths conflict for node %s: %v, %v", n.key, traversed, targetPath)
			}

			if i == len(targetPath)-1 {
				// A longer path already maps into a sub-field of this one.
				if found {
					return fmt.Errorf("two terminal field paths conflict for node %s: %v, %v", n.key, traversed, targetPath)
				}
				cur[seg] = struct{}{}
				break
			}

			// Reuse an existing intermediate node so that earlier mappings
			// below this prefix are kept and remain visible to conflict checks.
			child, isBranch := existing.(map[string]any)
			if !isBranch {
				child = make(map[string]any)
				cur[seg] = child
			}
			cur = child
		}
	}

	return nil
}

// WorkflowBranch represents a branch added to a workflow.
// Each branch may define its own end nodes and mappings.
type WorkflowBranch struct {
	fromNodeKey string
	*GraphBranch
}

// AddBranch adds a branch to the workflow.
//
// End Nodes Field Mappings:
// End nodes of the branch are required to define their own field mappings.
// This is a key distinction between Graph's Branch and Workflow's Branch:
//   - Graph's Branch: Automatically passes its input to the selected node.
// - Workflow's Branch: Does not pass its input to the selected node. func (wf *Workflow[I, O]) AddBranch(fromNodeKey string, branch *GraphBranch) *WorkflowBranch { wb := &WorkflowBranch{ fromNodeKey: fromNodeKey, GraphBranch: branch, } wf.workflowBranches = append(wf.workflowBranches, wb) return wb } // AddEnd connects a node to END with optional field mappings. // Deprecated: use *Workflow[I,O].End() to obtain a WorkflowNode instance for END, then work with it just like a normal WorkflowNode. func (wf *Workflow[I, O]) AddEnd(fromNodeKey string, inputs ...*FieldMapping) *Workflow[I, O] { for _, input := range inputs { input.fromNodeKey = fromNodeKey } _ = wf.g.addEdgeWithMappings(fromNodeKey, END, false, false, inputs...) return wf } func (wf *Workflow[I, O]) compile(ctx context.Context, options *graphCompileOptions) (*composableRunnable, error) { if wf.g.buildError != nil { return nil, wf.g.buildError } for _, wb := range wf.workflowBranches { for endNode := range wb.endNodes { if endNode == END { if _, ok := wf.dependencies[END]; !ok { wf.dependencies[END] = make(map[string]dependencyType) } wf.dependencies[END][wb.fromNodeKey] = branchDependency } else { n := wf.workflowNodes[endNode] n.dependencySetter(wb.fromNodeKey, branchDependency) } } _ = wf.g.addBranch(wb.fromNodeKey, wb.GraphBranch, true) } for _, n := range wf.workflowNodes { for _, addInput := range n.addInputs { if err := addInput(); err != nil { return nil, err } } n.addInputs = nil } for _, n := range wf.workflowNodes { if len(n.staticValues) > 0 { value := make(map[string]any, len(n.staticValues)) var paths []FieldPath for path, v := range n.staticValues { value[path] = v paths = append(paths, splitFieldPath(path)) } if err := n.checkAndAddMappedPath(paths); err != nil { return nil, err } pair := handlerPair{ invoke: func(in any) (any, error) { values := []any{in, value} return mergeValues(values, nil) }, transform: func(in streamReader) streamReader { sr := 
schema.StreamReaderFromArray([]map[string]any{value}) newS, err := mergeValues([]any{in, packStreamReader(sr)}, nil) if err != nil { errSR, errSW := schema.Pipe[map[string]any](1) errSW.Send(nil, err) errSW.Close() return packStreamReader(errSR) } return newS.(streamReader) }, } for i := range paths { wf.g.fieldMappingRecords[n.key] = append(wf.g.fieldMappingRecords[n.key], ToFieldPath(paths[i])) } wf.g.handlerPreNode[n.key] = []handlerPair{pair} } } // TODO: check indirect edges are legal return wf.g.compile(ctx, options) } func (wf *Workflow[I, O]) initNode(key string) *WorkflowNode { n := &WorkflowNode{ g: wf.g, key: key, staticValues: make(map[string]any), dependencySetter: func(fromNodeKey string, typ dependencyType) { if _, ok := wf.dependencies[key]; !ok { wf.dependencies[key] = make(map[string]dependencyType) } wf.dependencies[key][fromNodeKey] = typ }, mappedFieldPath: make(map[string]any), } wf.workflowNodes[key] = n return n } func (wf *Workflow[I, O]) getGenericHelper() *genericHelper { return wf.g.getGenericHelper() } func (wf *Workflow[I, O]) inputType() reflect.Type { return wf.g.inputType() } func (wf *Workflow[I, O]) outputType() reflect.Type { return wf.g.outputType() } func (wf *Workflow[I, O]) component() component { return wf.g.component() } ================================================ FILE: compose/workflow_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package compose import ( "context" "errors" "fmt" "io" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/internal/mock/components/embedding" "github.com/cloudwego/eino/internal/mock/components/indexer" "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) func TestWorkflow(t *testing.T) { ctx := context.Background() type structA struct { Field1 string Field2 int Field3 []any } type structB struct { Field1 string Field2 int } type structC struct { Field1 string } type structE struct { Field1 string Field2 string Field3 []any } type structF struct { Field1 string Field2 string Field3 []any B int StateTemp string } RegisterStreamChunkConcatFunc(func(ts []*structF) (*structF, error) { ret := &structF{} for _, tt := range ts { ret.Field1 += tt.Field1 ret.Field2 += tt.Field2 ret.Field3 = append(ret.Field3, tt.Field3...) ret.B += tt.B ret.StateTemp += tt.StateTemp } return ret, nil }) type state struct { temp string } type structEnd struct { Field1 string } subGraph := NewGraph[string, *structB]() _ = subGraph.AddLambdaNode( "1", InvokableLambda(func(ctx context.Context, input string) (*structB, error) { return &structB{Field1: input, Field2: 33}, nil }), ) _ = subGraph.AddEdge(START, "1") _ = subGraph.AddEdge("1", END) subChain := NewChain[any, *structC](). AppendLambda(InvokableLambda(func(_ context.Context, in any) (*structC, error) { return &structC{Field1: fmt.Sprintf("%d", in)}, nil })) type struct2 struct { F map[string]any } subWorkflow := NewWorkflow[[]any, []any]() subWorkflow.AddLambdaNode( "1", InvokableLambda(func(_ context.Context, in []any) ([]any, error) { return in, nil }), WithOutputKey("key")). AddInput(START) // []any -> map["key"][]any subWorkflow.AddLambdaNode( "2", InvokableLambda(func(_ context.Context, in []any) ([]any, error) { return in, nil }), WithInputKey("key"), WithOutputKey("key1")). 
AddInput("1") // map["key"][]any -> []any -> map["key1"][]any subWorkflow.AddLambdaNode( "3", InvokableLambda(func(_ context.Context, in struct2) (map[string]any, error) { return in.F, nil }), ). AddInput("2", ToField("F")) // map["key1"][]any -> map["F"]map["key1"][]any -> struct2{F: map["key1"]any} -> map["key1"][]any subWorkflow.AddLambdaNode( "4", InvokableLambda(func(_ context.Context, in []any) ([]any, error) { return in, nil }), WithInputKey("key1"), ). AddInput("3") // map["key1"][]any -> []any subWorkflow.End().AddInput("4") w := NewWorkflow[*structA, *structEnd](WithGenLocalState(func(context.Context) *state { return &state{} })) w. AddGraphNode("B", subGraph, WithStatePostHandler(func(ctx context.Context, out *structB, state *state) (*structB, error) { state.temp = out.Field1 return out, nil })). AddInput(START, FromField("Field1")) w. AddGraphNode("C", subChain). AddInput(START, FromField("Field2")) w. AddGraphNode("D", subWorkflow). AddInput(START, FromField("Field3")) w. AddLambdaNode( "E", TransformableLambda(func(_ context.Context, in *schema.StreamReader[structE]) (*schema.StreamReader[structE], error) { return schema.StreamReaderWithConvert(in, func(in structE) (structE, error) { if len(in.Field1) > 0 { in.Field1 = "E:" + in.Field1 } if len(in.Field2) > 0 { in.Field2 = "E:" + in.Field2 } return in, nil }), nil }), WithStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { temp := state.temp return schema.StreamReaderWithConvert(in, func(v structE) (structE, error) { if len(v.Field3) > 0 { v.Field3 = append(v.Field3, "Pre:"+temp) } return v, nil }), nil }), WithStreamStatePostHandler(func(ctx context.Context, out *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { return schema.StreamReaderWithConvert(out, func(v structE) (structE, error) { if len(v.Field1) > 0 { v.Field1 = v.Field1 + "+Post" } return v, nil }), nil })). 
AddInput("B", MapFields("Field1", "Field1")).
		AddInput("C", MapFields("Field1", "Field2")).
		AddInput("D", ToField("Field3"))

	w.
		AddLambdaNode(
			"F",
			InvokableLambda(func(ctx context.Context, in *structF) (string, error) {
				return fmt.Sprintf("%v_%v_%v_%v_%v", in.Field1, in.Field2, in.Field3, in.B, in.StateTemp), nil
			}),
			WithStatePreHandler(func(ctx context.Context, in *structF, state *state) (*structF, error) {
				in.StateTemp = state.temp
				return in, nil
			}),
		).
		AddInput("B", MapFields("Field2", "B")).
		AddInput("E",
			MapFields("Field1", "Field1"),
			MapFields("Field2", "Field2"),
			MapFields("Field3", "Field3"),
		)

	w.End().AddInput("F", ToField("Field1"))

	compiled, err := w.Compile(ctx)
	assert.NoError(t, err)

	input := &structA{
		Field1: "1",
		Field2: 2,
		Field3: []any{
			1,
			"good",
		},
	}
	out, err := compiled.Invoke(ctx, input)
	assert.NoError(t, err)
	assert.Equal(t, &structEnd{"E:1+Post_E:2_[1 good Pre:1]_33_1"}, out)

	outStream, err := compiled.Stream(ctx, input)
	assert.NoError(t, err)
	defer outStream.Close()
	for {
		chunk, err := outStream.Recv()
		if err != nil {
			if err == io.EOF {
				break
			}
			t.Error(err)
			return
		}

		assert.Equal(t, &structEnd{"E:1+Post_E:2_[1 good Pre:1]_33_1"}, chunk)
	}
}

// TestWorkflowWithMap checks that map[string]any keys can be mapped both to
// another map's keys and to a struct field, and back into the map output.
func TestWorkflowWithMap(t *testing.T) {
	ctx := context.Background()

	type structA struct {
		F1 any
	}

	wf := NewWorkflow[map[string]any, map[string]any]()
	wf.AddLambdaNode("lambda1", InvokableLambda(func(ctx context.Context, in map[string]any) (map[string]any, error) {
		return in, nil
	})).AddInput(START, MapFields("map_key", "lambda1_key"))
	wf.AddLambdaNode("lambda2", InvokableLambda(func(ctx context.Context, in *structA) (*structA, error) {
		return in, nil
	})).AddInput(START, MapFields("map_key", "F1"))
	wf.End().AddInput("lambda1", MapFields("lambda1_key", "end_lambda1"))
	wf.End().AddInput("lambda2", MapFields("F1", "end_lambda2"))
	r, err := wf.Compile(ctx)
	assert.NoError(t, err)
	out, err := r.Invoke(ctx, map[string]any{"map_key": "value"})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{"end_lambda1": "value", "end_lambda2": "value"}, out)
}

// TestWorkflowWithNestedFieldMappings covers FromFieldPath / ToFieldPath /
// MapFieldPaths across nested structs, struct pointers, maps and interfaces,
// including the compile-time and request-time errors for invalid paths.
func TestWorkflowWithNestedFieldMappings(t *testing.T) {
	ctx := context.Background()

	type structA struct {
		F1 string
	}

	type structB struct {
		F1 *structA
		F2 map[string]any
		F3 int
		F4 any
		F5 map[string]structA
		F6 structA
	}

	t.Run("from struct.struct.field", func(t *testing.T) {
		wf := NewWorkflow[*structB, string]()
		wf.End().AddInput(START, FromFieldPath([]string{"F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, &structB{
			F1: &structA{
				F1: "hello",
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)

		wf = NewWorkflow[*structB, string]()
		wf.End().AddInput(START, FromFieldPath([]string{"F1", "F2"}))
		_, err = wf.Compile(ctx)
		assert.ErrorContains(t, err, "has no field[F2]")
	})

	t.Run("to struct.(non-ptr)struct.field", func(t *testing.T) {
		wf := NewWorkflow[string, *structB]()
		wf.End().AddInput(START, ToFieldPath([]string{"F6", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, &structB{
			F6: structA{
				F1: "hello",
			},
		}, out)
	})

	t.Run("to map.(non-ptr)struct.field", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]structA]()
		wf.End().AddInput(START, ToFieldPath([]string{"key", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]structA{
			"key": {
				F1: "hello",
			},
		}, out)
	})

	t.Run("from map.map.field", func(t *testing.T) {
		wf := NewWorkflow[map[string]map[string]string, string]()
		wf.End().AddInput(START, FromFieldPath([]string{"F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, map[string]map[string]string{
			"F1": {
				"F1": "hello",
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)

		_, err = r.Invoke(ctx, map[string]map[string]string{
			"F1": {
				"F2": "hello",
			},
		})
		var ie *internalError
		// Only inspect ie when errors.As succeeded; otherwise ie is nil.
		if assert.True(t, errors.As(err, &ie)) {
			var myErr *errMapKeyNotFound
			assert.True(t, errors.As(ie.origError, &myErr))
		}
	})

	t.Run("from struct.map.field", func(t *testing.T) {
		wf := NewWorkflow[*structB, string]()
		wf.End().AddInput(START, FromFieldPath([]string{"F2", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, &structB{
			F2: map[string]any{
				"F1": "hello",
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)

		_, err = r.Invoke(ctx, &structB{
			F2: map[string]any{
				"F2": "hello",
			},
		})
		var ie *internalError
		// Only inspect ie when errors.As succeeded; otherwise ie is nil.
		if assert.True(t, errors.As(err, &ie)) {
			var myErr *errMapKeyNotFound
			assert.True(t, errors.As(ie.origError, &myErr))
		}
	})

	t.Run("from map.struct.field", func(t *testing.T) {
		wf := NewWorkflow[map[string]*structA, string]()
		wf.End().AddInput(START, FromFieldPath([]string{"F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, map[string]*structA{
			"F1": {
				F1: "hello",
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)

		wf = NewWorkflow[map[string]*structA, string]()
		wf.End().AddInput(START, FromFieldPath([]string{"F1", "F2"}))
		_, err = wf.Compile(ctx)
		assert.ErrorContains(t, err, "has no field[F2]")
	})

	t.Run("from map[string]any.field", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, string]()
		wf.End().AddInput(START, FromFieldPath([]string{"F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, map[string]any{
			"F1": &structA{
				F1: "hello",
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)

		out, err = r.Invoke(ctx, map[string]any{
			"F1": map[string]any{
				"F1": "hello",
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)

		_, err = r.Invoke(ctx, map[string]any{
			"F1": 1,
		})
		var ie *internalError
		// Only inspect ie when errors.As succeeded; otherwise ie is nil.
		if assert.True(t, errors.As(err, &ie)) {
			var myErr *errInterfaceNotValidForFieldMapping
			assert.True(t, errors.As(ie.origError, &myErr))
		}
	})

	t.Run("to struct.struct.field", func(t *testing.T) {
		wf := NewWorkflow[string, *structB]()
		wf.End().AddInput(START, ToFieldPath([]string{"F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, &structB{
			F1: &structA{
				F1: "hello",
			},
		}, out)

		wf = NewWorkflow[string, *structB]()
		wf.End().AddInput(START, ToFieldPath([]string{"F1", "F2"}))
		_, err = wf.Compile(ctx)
		assert.ErrorContains(t, err, "has no field[F2]")
	})

	t.Run("to map.map.field", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]map[string]string]()
		wf.End().AddInput(START, ToFieldPath([]string{"F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]map[string]string{
			"F1": {
				"F1": "hello",
			},
		}, out)

		wf1 := NewWorkflow[string, map[string]map[string]int]()
		wf1.End().AddInput(START, ToFieldPath([]string{"F1", "F1"}))
		_, err = wf1.Compile(ctx)
		assert.ErrorContains(t, err, "field[string]-[int] is absolutely not assignable")
	})

	t.Run("to struct.map.field", func(t *testing.T) {
		wf := NewWorkflow[string, *structB]()
		wf.End().AddInput(START, ToFieldPath([]string{"F2", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, &structB{
			F2: map[string]any{
				"F1": "hello",
			},
		}, out)
	})

	t.Run("to map.struct.struct.field", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]*structB]()
		wf.End().AddInput(START, ToFieldPath([]string{"F1", "F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]*structB{
			"F1": {
				F1: &structA{
					F1: "hello",
				},
			},
		}, out)
	})

	t.Run("to struct.map.struct.field", func(t *testing.T) {
		wf := NewWorkflow[string, *structB]()
		wf.End().AddInput(START, ToFieldPath([]string{"F5", "key", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, &structB{
			F5: map[string]structA{
				"key": {
					F1: "hello",
				},
			},
		}, out)
	})

	t.Run("to map.map.struct(non-ptr).field", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]map[string]structA]()
		wf.End().AddInput(START, ToFieldPath([]string{"key1", "key2", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]map[string]structA{
			"key1": {
				"key2": {
					F1: "hello",
				},
			},
		}, out)
	})

	t.Run("to struct.int.field", func(t *testing.T) {
		wf := NewWorkflow[string, *structB]()
		wf.End().AddInput(START, ToFieldPath([]string{"F3", "F1", "F1"}))
		_, err := wf.Compile(ctx)
		assert.ErrorContains(t, err, "type[int] is not valid")
	})

	t.Run("to struct.any.field", func(t *testing.T) {
		wf := NewWorkflow[string, *structB]()
		wf.End().AddInput(START, ToFieldPath([]string{"F4", "F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, &structB{
			F4: map[string]any{
				"F1": map[string]any{
					"F1": "hello",
				},
			},
		}, out)
	})

	t.Run("to map.any.any.field", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]any]()
		wf.End().AddInput(START, ToFieldPath([]string{"Key1", "Key2", "Key3"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"Key1": map[string]any{
				"Key2": map[string]any{
					"Key3": "hello",
				},
			},
		}, out)
	})

	t.Run("to any", func(t *testing.T) {
		wf := NewWorkflow[string, any]()
		wf.End().AddInput(START)
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)
	})

	t.Run("to any.field", func(t *testing.T) {
		wf := NewWorkflow[string, any]()
		wf.End().AddInput(START, ToFieldPath([]string{"Key1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"Key1": "hello",
		}, out)
	})

	t.Run("to interface.field", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]fmt.Stringer]()
		wf.End().AddInput(START, ToFieldPath([]string{"Key1", "A"}))
		_, err := wf.Compile(ctx)
		assert.ErrorContains(t, err, "static check failed for mapping [from start to Key1\u001FA(field)], "+
			"the successor has intermediate interface type fmt.Stringer")
	})

	t.Run("both to map.any, and to map.any.field", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]any]()
		wf.End().AddInput(START, ToFieldPath([]string{"Key1"}), ToFieldPath([]string{"Key1", "Key2"}))
		_, err := wf.Compile(ctx)
		assert.ErrorContains(t, err, "two terminal field paths conflict")
	})

	t.Run("to map.any.any.field1, and to map.any.any.field2", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]any]()
		wf.End().AddInput(START, ToFieldPath([]string{"Key1", "Key2", "key3"}), ToFieldPath([]string{"Key1", "Key2", "key4"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"Key1": map[string]any{
				"Key2": map[string]any{
					"key3": "hello",
					"key4": "hello",
				},
			},
		}, out)
	})

	t.Run("from nested to nested", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, *structB]()
		wf.End().AddInput(START, MapFieldPaths([]string{"key1", "key2"}, []string{"F1", "F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, map[string]any{
			"key1": map[string]any{
				"key2": "hello",
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, &structB{
			F1: &structA{
				F1: "hello",
			},
		}, out)
	})

	t.Run("from nested to normal", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, *structA]()
		wf.End().AddInput(START, MapFieldPaths(FieldPath{"key1", "key2"}, FieldPath{"F1"}))
		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, map[string]any{
			"key1": map[string]any{
				"key2": "hello",
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, &structA{
			F1: "hello",
		}, out)
	})
}

// TestWorkflowCompile checks the error messages Workflow.Compile reports for
// malformed workflows (missing end, type mismatches, bad field mappings,
// duplicate or unknown nodes, non-string map keys).
func TestWorkflowCompile(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)

	t.Run("compile without add end", func(t *testing.T) {
		w := NewWorkflow[*schema.Message, []*schema.Message]()
		w.AddToolsNode("1", &ToolsNode{}).AddInput(START)
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "end node not set")
	})

	t.Run("type mismatch", func(t *testing.T) {
		w := NewWorkflow[string, string]()
		w.AddToolsNode("1", &ToolsNode{}).AddInput(START)
		w.End().AddInput("1")
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, " mismatch")
	})

	t.Run("predecessor's output not struct/struct ptr/map, mapping has FromField", func(t *testing.T) {
		w := NewWorkflow[[]*schema.Document, []string]()
		w.AddIndexerNode("indexer", indexer.NewMockIndexer(ctrl)).AddInput(START, FromField("F1"))
		w.End().AddInput("indexer")
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "predecessor output type should be struct")
	})

	t.Run("successor's input not struct/struct ptr/map, mapping has ToField", func(t *testing.T) {
		w := NewWorkflow[[]string, [][]float64]()
		w.AddEmbeddingNode("embedder", embedding.NewMockEmbedder(ctrl)).AddInput(START, ToField("F1"))
		w.End().AddInput("embedder")
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "successor input type should be struct")
	})

	t.Run("map to non existing field in predecessor", func(t *testing.T) {
		w := NewWorkflow[*schema.Message, []*schema.Message]()
		w.AddToolsNode("tools_node", &ToolsNode{}).AddInput(START, FromField("non_exist"))
		w.End().AddInput("tools_node")
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "type[schema.Message] has no field[non_exist]")
	})

	t.Run("map to not exported field in successor", func(t *testing.T) {
		w := NewWorkflow[string, *FieldMapping]()
		w.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
			return input, nil
		})).AddInput(START)
		w.End().AddInput("1", ToField("to"))
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "has an unexported field[to]")
	})

	t.Run("map from not exported field in predecessor", func(t *testing.T) {
		w := NewWorkflow[*FieldMapping, string]()
		w.End().AddInput(START, FromField("from"))
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "has an unexported field[from]")
	})

	t.Run("duplicate node key", func(t *testing.T) {
		w := NewWorkflow[[]*schema.Message, []*schema.Message]()
		w.AddChatModelNode("1", model.NewMockChatModel(ctrl)).AddInput(START)
		w.AddToolsNode("1", &ToolsNode{}).AddInput("1")
		w.End().AddInput("1")
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "node '1' already present")
	})

	t.Run("from non-existing node", func(t *testing.T) {
		w := NewWorkflow[*schema.Message, []*schema.Message]()
		w.AddToolsNode("1", &ToolsNode{}).AddInput(START)
		w.End().AddInput("2")
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "edge start node '2' needs to be added to graph first")
	})

	t.Run("to map with non-string key type", func(t *testing.T) {
		w := NewWorkflow[string, map[int]any]()
		w.End().AddInput(START, ToField("1"))
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "type[map[int]interface {}] is not a map with string or string alias key")

		type stringAlias string
		w1 := NewWorkflow[string, map[stringAlias]any]()
		w1.End().AddInput(START, ToField("1"))
		r, err := w1.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[stringAlias]any{
			"1": "hello",
		}, out)
	})

	t.Run("from map with non-string key type", func(t *testing.T) {
		w := NewWorkflow[map[int]any, string]()
		w.End().AddInput(START, FromField("1"))
		_, err := w.Compile(ctx)
		assert.ErrorContains(t, err, "type[map[int]interface {}] is not a map with string or string alias key")

		type stringAlias string
		w1 := NewWorkflow[map[stringAlias]any, string]()
		w1.End().AddInput(START, FromField("1"))
		r, err := w1.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, map[stringAlias]any{
			"1": "hello",
		})
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)
	})
}

// TestFanInToSameDest covers several predecessors feeding one destination:
// distinct output keys merge fine, while two mappings onto the same
// destination field are rejected at compile time.
func TestFanInToSameDest(t *testing.T) {
	t.Run("traditional outputKey fan-in with map[string]any", func(t *testing.T) {
		wf := NewWorkflow[string, []*schema.Message]()
		wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return in, nil
		}), WithOutputKey("q1")).AddInput(START)
		wf.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return in, nil
		}), WithOutputKey("q2")).AddInput(START)
		wf.AddChatTemplateNode("prompt", prompt.FromMessages(schema.Jinja2, schema.UserMessage("{{q1}}_{{q2}}"))).
			AddInput("1", MapFields("q1", "q1")).
			AddInput("2", MapFields("q2", "q2"))
		wf.End().AddInput("prompt")
		c, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		out, err := c.Invoke(context.Background(), "query")
		assert.NoError(t, err)
		assert.Equal(t, []*schema.Message{{Role: schema.User, Content: "query_query"}}, out)
	})

	t.Run("fan-in to a field of map", func(t *testing.T) {
		type dest struct {
			F map[string]any
		}

		type in struct {
			A string
			B int
		}

		wf := NewWorkflow[in, dest]()
		wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return in, nil
		}), WithOutputKey("A")).AddInput(START, FromField("A"))
		wf.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, in int) (output int, err error) {
			return in, nil
		}), WithOutputKey("B")).AddInput(START, FromField("B"))
		wf.End().AddInput("1", ToField("F")).AddInput("2", ToField("F"))
		_, err := wf.Compile(context.Background())
		assert.ErrorContains(t, err, "two terminal field paths conflict for node end: [F], [F]")
	})
}

// TestIndirectEdge checks that AddInputWithOptions(..., WithNoDirectDependency())
// passes data from a non-adjacent node without adding an execution dependency.
func TestIndirectEdge(t *testing.T) {
	wf := NewWorkflow[string, map[string]any]()
	wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
		return in + "_" + in, nil
	})).AddInput(START)

	wf.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, in map[string]string) (output string, err error) {
		return in["1"] + "_" + in[START], nil
	})).AddInput("1", ToField("1")).
		AddInputWithOptions(START, []*FieldMapping{ToField(START)}, WithNoDirectDependency())

	wf.End().AddInput("2", ToField("2")).
		AddInputWithOptions("1", []*FieldMapping{ToField("1")}, WithNoDirectDependency())
	r, err := wf.Compile(context.Background())
	assert.NoError(t, err)
	out, err := r.Invoke(context.Background(), "query")
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{"1": "query_query", "2": "query_query_query"}, out)
}

// TestDependencyWithNoInput checks AddDependency edges, which order execution
// without transferring data, for both Invoke and Stream.
func TestDependencyWithNoInput(t *testing.T) {
	t.Run("simple case", func(t *testing.T) {
		wf := NewWorkflow[string, string]()
		wf.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return "useless", nil
		})).AddInput(START)
		wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return in + "_done", nil
		})).AddDependency("0").AddInputWithOptions(START, nil, WithNoDirectDependency())
		wf.End().AddInput("1")
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		out, err := r.Invoke(context.Background(), "hello")
		assert.NoError(t, err)
		assert.Equal(t, "hello_done", out)
	})

	t.Run("simple control flow: [Start] --> [Node '0'] --> [End]", func(t *testing.T) {
		// [Start] --> [Node "0"] --> [End]
		wf := NewWorkflow[map[string]any, map[string]any]()
		wf.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in map[string]any) (output map[string]any, err error) {
			return map[string]any{
				"result": "result from node 0",
			}, nil
		})).AddDependency(START)
		wf.End().AddInput("0", ToField("final_result")).
			AddInputWithOptions(START, []*FieldMapping{ToField("final_from_start")}, WithNoDirectDependency())
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)

		ret, err := r.Invoke(context.Background(), map[string]any{
			"input": "hello",
		})
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"final_result": map[string]any{
				"result": "result from node 0",
			},
			"final_from_start": map[string]any{
				"input": "hello",
			},
		}, ret)

		sRet, err := r.Stream(context.Background(), map[string]any{
			"input": "hello",
		})
		assert.NoError(t, err)
		ret, err = concatStreamReader(sRet)
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"final_result": map[string]any{
				"result": "result from node 0",
			},
			"final_from_start": map[string]any{
				"input": "hello",
			},
		}, ret)
	})
}

// TestStaticValue covers SetStaticValue:
//   - static values merged with dynamic mappings;
//   - conflicts with whole-output and same-path mappings;
//   - nodes whose inputs are made up entirely of static values.
func TestStaticValue(t *testing.T) {
	t.Run("prefill map", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]any]()
		wf.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in map[string]any) (output map[string]any, err error) {
			return in, nil
		})).
			AddInput(START, ToField(START)).
			SetStaticValue(FieldPath{"prefilled"}, "yo-ho")
		wf.End().AddInput("0")
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		out, err := r.Invoke(context.Background(), "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{"prefilled": "yo-ho", START: "hello"}, out)

		streamOut, err := r.Stream(context.Background(), "hello")
		if !assert.NoError(t, err) {
			return
		}
		defer streamOut.Close()

		out = map[string]any{}
		for {
			chunk, err := streamOut.Recv()
			if err == io.EOF {
				break
			}
			// Stop on any other error: Recv may keep returning it, which would
			// otherwise spin this loop forever.
			if !assert.NoError(t, err) {
				return
			}
			for k, v := range chunk {
				out[k] = v
			}
		}
		assert.Equal(t, map[string]any{"prefilled": "yo-ho", START: "hello"}, out)
	})

	t.Run("static value and to-all mapping conflict", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, map[string]any]()
		wf.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in map[string]any) (output map[string]any, err error) {
			return in, nil
		})).
			AddInput(START).
			SetStaticValue(
				FieldPath{"prefilled"},
				"yo-ho",
			)
		wf.End().AddInput("0")
		_, err := wf.Compile(context.Background())
		assert.ErrorContains(t, err, "entire output has already been mapped for node: 0")
	})

	t.Run("static value and dynamic mapping conflict", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]any]()
		wf.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in map[string]any) (output map[string]any, err error) {
			return in, nil
		})).
			AddInput(START, ToField("prefilled")).
			SetStaticValue(FieldPath{"prefilled"}, "yo-ho")
		wf.End().AddInput("0")
		_, err := wf.Compile(context.Background())
		assert.ErrorContains(t, err, "two terminal field paths conflict for node 0: [prefilled], [prefilled]")
	})

	t.Run("all inputs are static values", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]any]()
		wf.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in map[string]any) (output map[string]any, err error) {
			return in, nil
		})).
			AddDependency(START).
			SetStaticValue(FieldPath{"a", "b"}, "a_b").
			SetStaticValue(FieldPath{"c", "d"}, "c_d").
			SetStaticValue(FieldPath{"a", "d"}, "a_d")
		wf.End().AddInput("0")
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		out, err := r.Invoke(context.Background(), "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"a": map[string]any{
				"b": "a_b",
				"d": "a_d",
			},
			"c": map[string]any{
				"d": "c_d",
			},
		}, out)

		type a struct {
			B string
			D string
		}

		type s struct {
			A a
			C map[string]any
		}

		wf1 := NewWorkflow[string, *s]()
		wf1.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in map[string]any) (output map[string]any, err error) {
			return in, nil
		})).
			AddDependency(START).
			SetStaticValue(FieldPath{"A", "B"}, "a_b").
			SetStaticValue(FieldPath{"C", "D"}, "c_d").
			SetStaticValue(FieldPath{"A", "D"}, "a_d")
		wf1.End().AddInput("0", MapFieldPaths(FieldPath{"A", "B"}, FieldPath{"A", "B"}),
			MapFieldPaths(FieldPath{"A", "D"}, FieldPath{"A", "D"}),
			MapFields("C", "C"))
		r1, err := wf1.Compile(context.Background())
		assert.NoError(t, err)
		out1, err := r1.Stream(context.Background(), "hello")
		// out1 is unusable if Stream failed; bail out instead of dereferencing it.
		if !assert.NoError(t, err) {
			return
		}
		outChunk, err := out1.Recv()
		out1.Close()
		assert.NoError(t, err)
		assert.Equal(t, &s{
			A: a{
				B: "a_b",
				D: "a_d",
			},
			C: map[string]any{
				"D": "c_d",
			},
		}, outChunk)
	})
}

// TestBranch covers graph branches inside a workflow: routing to a node or
// straight to END, branches whose passthrough node has several predecessors,
// and a branch that skips a node with no input.
func TestBranch(t *testing.T) {
	ctx := context.Background()

	t.Run("simple branch: one predecessor, two successor, one of them is END", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]any]()
		wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return in + "_" + in, nil
		})).AddInputWithOptions(START, nil, WithNoDirectDependency())

		wf.AddPassthroughNode("branch_1").AddInput(START, ToField(START))

		branch := NewGraphBranch(func(ctx context.Context, in map[string]any) (string, error) {
			if in[START] == "hello" {
				return "1", nil
			}
			return END, nil
		}, map[string]bool{
			"1": true,
			END: true,
		})
		wf.AddBranch("branch_1", branch)

		wf.End().AddInput("1", ToField("1")).AddInputWithOptions(START, []*FieldMapping{ToField(START)}, WithNoDirectDependency())

		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"1":   "hello_hello",
			START: "hello",
		}, out)

		out, err = r.Invoke(ctx, "world")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			START: "world",
		}, out)
	})

	t.Run("multiple predecessors", func(t *testing.T) {
		wf := NewWorkflow[string, map[string]any]()
		wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return in + "_" + in, nil
		})).AddInput(START)

		wf.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return in + "_" + in, nil
		})).AddInputWithOptions("1", nil, WithNoDirectDependency())

		wf.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
			return in + "_" + in, nil
		})).AddInput(START)

		wf.AddPassthroughNode("branch_1").AddInput(START, ToField(START)).AddInput("1", ToField("1")).AddDependency("0")
		wf.AddBranch("branch_1", NewGraphBranch(func(ctx context.Context, in map[string]any) (string, error) {
			if in[START].(string) == "hello" {
				return "2", nil
			}
			return END, nil
		}, map[string]bool{
			"2": true,
			END: true,
		}))

		wf.End().AddInput("2", ToField("2")).AddInputWithOptions(START, []*FieldMapping{ToField(START)}, WithNoDirectDependency())

		r, err := wf.Compile(ctx)
		assert.NoError(t, err)
		out, err := r.Invoke(ctx, "hello")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{"2": "hello_hello_hello_hello", START: "hello"}, out)

		out, err = r.Invoke(ctx, "world")
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{START: "world"}, out)
	})

	t.Run("empty input for node after branch", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, map[string]any]()
		wf.AddLambdaNode("start_1", InvokableLambda(func(ctx context.Context, input map[string]any) (map[string]any, error) {
			return map[string]any{}, nil
		})).AddInput("start")
		wf.AddLambdaNode("branch_1", InvokableLambda(func(ctx context.Context, input map[string]any) (map[string]any, error) {
			return map[string]any{}, nil
		}))
		wf.AddPassthroughNode("my_branch").AddInput("start_1")
		wf.AddBranch("my_branch", NewGraphBranch(func(ctx context.Context, input map[string]any) (string, error) {
			return END, nil
		}, map[string]bool{
			"branch_1": true,
			END:        true,
		}))
		wf.End().AddInput("branch_1")
		runner, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		resp, err := runner.Invoke(context.Background(), map[string]any{})
		assert.NoError(t, err)
		// Expected value first, per testify's (expected, actual) convention.
		assert.Equal(t, (map[string]any)(nil), resp)
	})
}

// goodInterface is the interface-typed source field in
// TestMayAssignableFieldMapping; its concrete type is only known at runtime.
type goodInterface interface {
	GOOD()
}

// goodStruct is the concrete goodInterface implementation used by
// TestMayAssignableFieldMapping.
type goodStruct struct{}

// GOOD implements goodInterface; it does nothing.
func (g *goodStruct) GOOD() {}

// TestMayAssignableFieldMapping checks that mapping from an interface-typed
// field (goodInterface) to a concrete-typed input (*goodStruct) compiles and
// succeeds when the runtime value has the right concrete type.
func TestMayAssignableFieldMapping(t *testing.T) {
	type in struct {
		A goodInterface
	}

	wf := NewWorkflow[in, *goodStruct]()
	wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input *goodStruct) (output goodInterface, err error) {
		return input, nil
	})).
		AddInput(START, FromField("A"))
	wf.End().AddInput("1")

	ctx := context.Background()
	r, err := wf.Compile(ctx)
	assert.NoError(t, err)
	result, err := r.Invoke(ctx, in{A: &goodStruct{}})
	assert.NoError(t, err)
	result.GOOD()
}

// TestNilValue covers field mappings whose source value is nil, between maps
// and structs, and into types that cannot hold nil.
func TestNilValue(t *testing.T) {
	t.Run("from map key with a nil value to map key", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, map[string]any]()
		wf.End().AddInput(START, MapFields("a", "a"))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		result, err := r.Invoke(context.Background(), map[string]any{"a": nil})
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{"a": nil}, result)
	})

	t.Run("from nil struct field to map key", func(t *testing.T) {
		type in struct {
			A *string
		}

		wf := NewWorkflow[in, map[string]any]()
		wf.End().AddInput(START, MapFields("A", "A"))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		result, err := r.Invoke(context.Background(), in{A: nil})
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{"A": (*string)(nil)}, result)
	})

	t.Run("from map key with a nil value to struct field", func(t *testing.T) {
		type out struct {
			A *string
		}

		wf := NewWorkflow[map[string]any, out]()
		wf.End().AddInput(START, MapFields("A", "A"))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		result, err := r.Invoke(context.Background(), map[string]any{"A": nil})
		assert.NoError(t, err)
		assert.Equal(t, out{A: (*string)(nil)}, result)
	})

	t.Run("from nil struct field to struct field", func(t *testing.T) {
		type inOut struct {
			A *string
		}

		wf := NewWorkflow[inOut, inOut]()
		wf.End().AddInput(START, MapFields("A", "A"))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		result, err := r.Invoke(context.Background(), inOut{A: nil})
		assert.NoError(t, err)
		assert.Equal(t, inOut{A: (*string)(nil)}, result)
	})

	t.Run("from nil to a type that can't be nil", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, int]()
		wf.End().AddInput(START, FromField("a"))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), map[string]any{"a": nil})
		assert.ErrorContains(t, err, "runtime check failed for mapping [from a(field) of start], field[]-[int] is absolutely not assignable")
	})

	t.Run("from nil to a map other than map[string]any", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, map[string]fmt.Stringer]()
		wf.End().AddInput(START, MapFields("a", "a"))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		out, err := r.Invoke(context.Background(), map[string]any{"a": nil})
		assert.NoError(t, err)
		assert.Equal(t, map[string]fmt.Stringer{
			"a": nil,
		}, out)
	})
}

// TestStreamFieldMap checks that field mappings over a stream combine
// partial chunks, each carrying a different key, into the full result.
func TestStreamFieldMap(t *testing.T) {
	t.Run("multiple incomplete chunks in source stream", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, map[string]any]()
		wf.End().AddInput(START, MapFields("a", "a"), MapFields("b", "b"))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		sr, sw := schema.Pipe[map[string]any](2)
		sw.Send(map[string]any{"a": 1}, nil)
		sw.Send(map[string]any{"b": 2}, nil)
		sw.Close()
		outputS, err := r.Transform(context.Background(), sr)
		assert.NoError(t, err)
		result, err := concatStreamReader(outputS)
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{"a": 1, "b": 2}, result)
	})
}

// TestRuntimeTypeCheck streams a map[string]any input whose field types are
// only checked at runtime. The merged output must arrive as one chunk,
// followed by io.EOF.
func TestRuntimeTypeCheck(t *testing.T) {
	g := NewWorkflow[map[string]any, any]()
	_ = g.
		AddLambdaNode("A", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
			return input, nil
		})).
		AddInput(START, FromField("A"))
	_ = g.AddLambdaNode("B", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
		return input, nil
	})).
		AddInput(START, FromField("B"))
	_ = g.AddLambdaNode("MergeA", InvokableLambda(func(ctx context.Context, input map[string]any) (output map[string]any, err error) {
		return input, nil
	})).
		AddInput("A", ToField("a")).
		AddInput("B", ToField("b"))
	g.End().AddInput("MergeA")

	ctx := context.Background()
	r, err := g.Compile(ctx)
	assert.NoError(t, err)
	result, err := r.Stream(ctx, map[string]any{"A": "1", "B": "2"})
	if !assert.NoError(t, err) {
		return
	}
	defer result.Close()

	chunk, err := result.Recv()
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{"a": "1", "b": "2"}, chunk)

	_, err = result.Recv()
	assert.True(t, errors.Is(err, io.EOF))
}

// TestIntermediateMappingSource covers FromFieldPath traversals where an
// intermediate value on the path is nil or sits behind an interface. It checks
// the error each case reports through both Invoke and Stream/Transform.
func TestIntermediateMappingSource(t *testing.T) {
	t.Run("intermediate any source is nil", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, any]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"a", "b"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), map[string]any{
			"a": nil,
		})
		assert.ErrorContains(t, err, "intermediate source value on path=[a b] is nil for type [interface {}]")

		outStream, err := r.Transform(context.Background(), schema.StreamReaderFromArray([]map[string]any{
			{
				"a": nil,
			},
			{
				"b": "ok",
			},
		}))
		assert.NoError(t, err)
		_, err = outStream.Recv()
		assert.ErrorContains(t, err, "intermediate source value on path=[a b] is nil for type [interface {}]")
		outStream.Close()
	})

	t.Run("intermediate map source is nil", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, any]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"a"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), nil)
		assert.ErrorContains(t, err, "intermediate source value on path=[a] is nil for map type [map[string]interface {}]")

		outStream, err := r.Stream(context.Background(), nil)
		assert.NoError(t, err)
		_, err = outStream.Recv()
		assert.ErrorContains(t, err, "intermediate source value on path=[a] is nil for map type [map[string]interface {}]")
		outStream.Close()
	})

	t.Run("intermediate map ptr source is nil", func(t *testing.T) {
		wf := NewWorkflow[*map[string]any, any]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"a"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), nil)
		assert.ErrorContains(t, err, "intermediate source value on path=[a] is nil for type [*map[string]interface {}]")

		outStream, err := r.Stream(context.Background(), nil)
		assert.NoError(t, err)
		_, err = outStream.Recv()
		assert.ErrorContains(t, err, "intermediate source value on path=[a] is nil for type [*map[string]interface {}]")
		outStream.Close()
	})

	t.Run("intermediate struct ptr source is nil", func(t *testing.T) {
		type inner struct {
			A string
		}

		wf := NewWorkflow[map[string]*inner, string]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"I", "A"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), map[string]*inner{"I": nil})
		assert.ErrorContains(t, err, "intermediate source value on path=[I A] is nil")

		outStream, err := r.Stream(context.Background(), map[string]*inner{"I": nil})
		assert.NoError(t, err)
		_, err = outStream.Recv()
		assert.ErrorContains(t, err, "intermediate source value on path=[I A] is nil")
		outStream.Close()
	})

	t.Run("intermediate interface source is nil", func(t *testing.T) {
		wf := NewWorkflow[map[string]fmt.Stringer, string]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"a", "b"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), map[string]fmt.Stringer{"a": nil})
		assert.ErrorContains(t, err, "intermediate source value on path=[a b] is nil for type [fmt.Stringer]")

		outStream, err := r.Stream(context.Background(), map[string]fmt.Stringer{"a": nil})
		assert.NoError(t, err)
		_, err = outStream.Recv()
		assert.ErrorContains(t, err, "intermediate source value on path=[a b] is nil for type [fmt.Stringer]")
		outStream.Close()
	})

	t.Run("intermediate interface source valid", func(t *testing.T) {
		wf := NewWorkflow[map[string]fmt.Stringer, string]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"a", "A"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		out, err := r.Invoke(context.Background(), map[string]fmt.Stringer{"a": &goodStruct2{A: "hello"}})
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)

		outStream, err := r.Stream(context.Background(), map[string]fmt.Stringer{"a": &goodStruct2{A: "hello"}})
		assert.NoError(t, err)
		out, err = outStream.Recv()
		assert.NoError(t, err)
		assert.Equal(t, "hello", out)
		outStream.Close()
	})

	t.Run("intermediate interface source, source field not found at request time", func(t *testing.T) {
		wf := NewWorkflow[map[string]fmt.Stringer, string]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"a", "B"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), map[string]fmt.Stringer{"a": &goodStruct2{A: "hello"}})
		assert.ErrorContains(t, err, "field mapping from a struct field, but field not found. field=B")

		outStream, err := r.Stream(context.Background(), map[string]fmt.Stringer{"a": &goodStruct2{A: "hello"}})
		assert.NoError(t, err)
		_, err = outStream.Recv()
		assert.ErrorContains(t, err, "field mapping from a struct field, but field not found. field=B")
		outStream.Close()
	})

	t.Run("intermediate interface source, source field not exported at request time", func(t *testing.T) {
		wf := NewWorkflow[map[string]fmt.Stringer, string]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"a", "c"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), map[string]fmt.Stringer{"a": &goodStruct2{A: "hello", c: "c"}})
		assert.ErrorContains(t, err, "field mapping from a struct field, but field not exported.")

		outStream, err := r.Stream(context.Background(), map[string]fmt.Stringer{"a": &goodStruct2{A: "hello", c: "c"}})
		assert.NoError(t, err)
		_, err = outStream.Recv()
		assert.ErrorContains(t, err, "field mapping from a struct field, but field not exported.")
		outStream.Close()
	})

	t.Run("intermediate interface source, type mismatch at request time", func(t *testing.T) {
		wf := NewWorkflow[map[string]fmt.Stringer, int]()
		wf.End().AddInput(START, FromFieldPath(FieldPath{"a", "A"}))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		_, err = r.Invoke(context.Background(), map[string]fmt.Stringer{"a": &goodStruct2{A: "hello"}})
		assert.ErrorContains(t, err, "runtime check failed for mapping [from a\x1fA(field) of start], field[string]-[int] is absolutely not assignable")

		outStream, err := r.Stream(context.Background(), map[string]fmt.Stringer{"a": &goodStruct2{A: "hello"}})
		assert.NoError(t, err)
		_, err = outStream.Recv()
		assert.ErrorContains(t, err, "runtime check failed for mapping [from a\u001FA(field) of start], field[string]-[int] is absolutely not assignable")
		outStream.Close()
	})
}

// goodStruct2 implements fmt.Stringer. TestIntermediateMappingSource stores it
// behind a fmt.Stringer, so field paths into it can only be resolved at request
// time. The unexported field c checks that mapping from an unexported field is
// rejected.
type goodStruct2 struct {
	A string
	c string
}

// String implements fmt.Stringer by returning A.
func (g *goodStruct2) String() string {
	return g.A
}

// TestSetFanInMergeConfig_RealStreamNode_Workflow enables
// StreamMergeWithSourceEOF on END. The merged stream must then surface one
// source-EOF error per upstream stream node, identified via
// schema.GetSourceName, alongside the merged data chunks.
func TestSetFanInMergeConfig_RealStreamNode_Workflow(t *testing.T) {
	wf := NewWorkflow[int, map[string]int]()
	wf.AddLambdaNode("s1", StreamableLambda(func(ctx context.Context, input int) (*schema.StreamReader[int], error) {
		sr, sw := schema.Pipe[int](2)
		sw.Send(input+1, nil)
		sw.Send(input+2, nil)
		sw.Close()
		return sr, nil
	})).AddInput(START)
	wf.AddLambdaNode("s2", StreamableLambda(func(ctx context.Context, input int) (*schema.StreamReader[int], error) {
		sr, sw := schema.Pipe[int](2)
		sw.Send(input+10, nil)
		sw.Send(input+20, nil)
		sw.Close()
		return sr, nil
	})).AddInput(START)
	wf.End().AddInput("s1", ToField("s1")).AddInput("s2", ToField("s2"))

	r, err := wf.Compile(context.Background(), WithFanInMergeConfig(map[string]FanInMergeConfig{END: {StreamMergeWithSourceEOF: true}}))
	assert.NoError(t, err)

	sr, err := r.Stream(context.Background(), 1)
	if !assert.NoError(t, err) {
		return
	}
	defer sr.Close()

	merged := make(map[string]map[int]bool)
	var sourceEOFCount int
	sourceNames := make(map[string]bool)
	for {
		m, e := sr.Recv()
		if e != nil {
			if name, ok := schema.GetSourceName(e); ok {
				sourceEOFCount++
				sourceNames[name] = true
				continue
			}
			if e == io.EOF {
				break
			}
			// Any other error is a failure; stop reading rather than keep
			// calling Recv on a failed stream.
			assert.NoError(t, e)
			return
		}
		for k, v := range m {
			if merged[k] == nil {
				merged[k] = make(map[int]bool)
			}
			merged[k][v] = true
		}
	}
	assert.Equal(t, map[string]map[int]bool{"s1": {2: true, 3: true}, "s2": {11: true, 21: true}}, merged)
	assert.Equal(t, 2, sourceEOFCount, "should receive SourceEOF for each input stream when StreamMergeWithSourceEOF is true")
	assert.True(t, sourceNames["s1"], "should receive SourceEOF from s1")
	assert.True(t, sourceNames["s2"], "should receive SourceEOF from s2")
}

func TestCustomExtractor(t *testing.T) {
	t.Run("custom extract from array element", func(t *testing.T) {
		wf := NewWorkflow[[]int, map[string]int]()
		wf.End().AddInput(START, ToField("a", WithCustomExtractor(func(input any) (any, error) {
			return input.([]int)[0], nil
		})))
		r, err := wf.Compile(context.Background())
		assert.NoError(t, err)
		result, err := r.Invoke(context.Background(), []int{1, 2})
		assert.NoError(t, err)
		assert.Equal(t, map[string]int{"a": 1}, result)
	})

	t.Run("mix custom extract with normal mapping", func(t *testing.T) {
		wf := NewWorkflow[map[string]any, map[string]int]()
		wf.End().AddInput(START, ToField("a",
WithCustomExtractor(func(input any) (any, error) { return input.(map[string]any)["a"].([]any)[0].(map[string]any)["c"], nil })), MapFields("b", "b")) r, err := wf.Compile(context.Background()) assert.NoError(t, err) result, err := r.Invoke(context.Background(), map[string]any{ "a": []any{ map[string]any{ "c": 1, }, }, "b": 2, }) assert.NoError(t, err) assert.Equal(t, map[string]int{"a": 1, "b": 2}, result) }) } func TestAddDependency(t *testing.T) { ctx := context.Background() wf := NewWorkflow[string, any]() wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { return in + "_" + in, nil })).AddDependency(START) wf.End().AddDependency("1") r, err := wf.Compile(ctx) assert.NoError(t, err) out, err := r.Invoke(ctx, "input") assert.NoError(t, err) assert.Equal(t, nil, out) } func TestIndirectDependencyWithBranch(t *testing.T) { t.Run("data only mapping across branch", func(t *testing.T) { wf := NewWorkflow[[]int, map[string]any]() wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input int) (output int, err error) { return input + 1, nil })). AddInputWithOptions(START, []*FieldMapping{ToField("", WithCustomExtractor(func(input any) (any, error) { inputList := input.([]int) if len(inputList) == 0 { return nil, fmt.Errorf("input list is empty") } return input.([]int)[0], nil }))}, WithNoDirectDependency()) wf.AddBranch(START, NewGraphBranch(func(ctx context.Context, in []int) (endNode string, err error) { if len(in) > 0 { return "1", nil } return END, nil }, map[string]bool{"1": true, END: true})) wf.End(). AddInput("1", ToField("output")). 
SetStaticValue(FieldPath{"static"}, 2) r, err := wf.Compile(context.Background()) assert.NoError(t, err) // skip lambda node "1" out, err := r.Invoke(context.Background(), nil) assert.NoError(t, err) assert.Equal(t, out, map[string]any{"static": 2}) // choose lambda node "1" out, err = r.Invoke(context.Background(), []int{1}) assert.NoError(t, err) assert.Equal(t, out, map[string]any{"output": 2, "static": 2}) }) t.Run("data only mapping across branch, with interrupt after branch", func(t *testing.T) { wf := NewWorkflow[[]int, map[string]any]() wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input int) (output int, err error) { return input + 1, nil })). AddInputWithOptions(START, []*FieldMapping{ToField("", WithCustomExtractor(func(input any) (any, error) { inputList := input.([]int) if len(inputList) == 0 { return nil, fmt.Errorf("input list is empty") } return input.([]int)[0], nil }))}, WithNoDirectDependency()) wf.AddBranch(START, NewGraphBranch(func(ctx context.Context, in []int) (endNode string, err error) { if len(in) > 0 { return "1", nil } return END, nil }, map[string]bool{"1": true, END: true})) wf.End(). AddInput("1", ToField("output")). 
SetStaticValue(FieldPath{"static"}, 2) r, err := wf.Compile(context.Background(), WithCheckPointStore(newInMemoryStore()), WithInterruptBeforeNodes([]string{"1"})) assert.NoError(t, err) // skip lambda node "1" out, err := r.Invoke(context.Background(), nil) assert.NoError(t, err) assert.Equal(t, out, map[string]any{"static": 2}) // choose lambda node "1" _, err = r.Invoke(context.Background(), []int{1}, WithCheckPointID("123")) _, ok := ExtractInterruptInfo(err) assert.True(t, ok) out, err = r.Invoke(context.Background(), nil, WithCheckPointID("123")) assert.NoError(t, err) assert.Equal(t, out, map[string]any{"output": 2, "static": 2}) }) } ================================================ FILE: doc.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package eino provides building blocks for agent workflows, // tools, and composable graph utilities. package eino ================================================ FILE: flow/agent/agent_option.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package agent defines common option types used by agents and multi-agents. package agent import "github.com/cloudwego/eino/compose" // AgentOption is the common option type for various agent and multi-agent implementations. // For options intended to use with underlying graph or components, use WithComposeOptions to specify. // For options intended to use with particular agent/multi-agent implementations, use WrapImplSpecificOptFn to specify. type AgentOption struct { implSpecificOptFn any composeOptions []compose.Option } // GetComposeOptions returns all compose options from the given agent options. func GetComposeOptions(opts ...AgentOption) []compose.Option { var result []compose.Option for _, opt := range opts { result = append(result, opt.composeOptions...) } return result } // WithComposeOptions returns an agent option that specifies compose options. func WithComposeOptions(opts ...compose.Option) AgentOption { return AgentOption{ composeOptions: opts, } } // WrapImplSpecificOptFn returns an agent option that specifies a function to modify the implementation-specific options. func WrapImplSpecificOptFn[T any](optFn func(*T)) AgentOption { return AgentOption{ implSpecificOptFn: optFn, } } // GetImplSpecificOptions returns the implementation-specific options from the given agent options. 
func GetImplSpecificOptions[T any](base *T, opts ...AgentOption) *T { if base == nil { base = new(T) } for i := range opts { opt := opts[i] if opt.implSpecificOptFn != nil { optFn, ok := opt.implSpecificOptFn.(func(*T)) if ok { optFn(base) } } } return base } ================================================ FILE: flow/agent/multiagent/host/callback.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package host import ( "context" "fmt" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/flow/agent" "github.com/cloudwego/eino/schema" template "github.com/cloudwego/eino/utils/callbacks" ) // MultiAgentCallback is the callback interface for host multi-agent. type MultiAgentCallback interface { OnHandOff(ctx context.Context, info *HandOffInfo) context.Context } // HandOffInfo is the info which will be passed to MultiAgentCallback.OnHandOff, representing a hand off event. type HandOffInfo struct { ToAgentName string Argument string } // ConvertCallbackHandlers converts []host.MultiAgentCallback to callbacks.Handler. 
func ConvertCallbackHandlers(handlers ...MultiAgentCallback) callbacks.Handler { onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context { msg := output.Message if msg == nil || msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 { return ctx } for _, cb := range handlers { for _, toolCall := range msg.ToolCalls { ctx = cb.OnHandOff(ctx, &HandOffInfo{ ToAgentName: toolCall.Function.Name, Argument: toolCall.Function.Arguments, }) } } return ctx } onChatModelEndWithStreamOutput := func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { go func() { msg, err := schema.ConcatMessageStream(schema.StreamReaderWithConvert(output, func(m *model.CallbackOutput) (*schema.Message, error) { return m.Message, nil })) if err != nil { fmt.Printf("concat message stream for host multi-agent failed: %v", err) return } for _, cb := range handlers { for _, tc := range msg.ToolCalls { _ = cb.OnHandOff(ctx, &HandOffInfo{ ToAgentName: tc.Function.Name, Argument: tc.Function.Arguments, }) } } }() return ctx } return template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{ OnEnd: onChatModelEnd, OnEndWithStreamOutput: onChatModelEndWithStreamOutput, }).Handler() } // convertCallbacks reads graph call options, extract host.MultiAgentCallback and convert it to callbacks.Handler. func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler { agentOptions := agent.GetImplSpecificOptions(&options{}, opts...) if len(agentOptions.agentCallbacks) == 0 { return nil } handlers := agentOptions.agentCallbacks return ConvertCallbackHandlers(handlers...) 
} ================================================ FILE: flow/agent/multiagent/host/compose.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package host import ( "context" "fmt" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/agent" "github.com/cloudwego/eino/schema" ) const ( defaultHostNodeKey = "host" // the key of the host node in the graph defaultHostPrompt = "decide which tool is best for the task and call only the best tool." specialistsAnswersCollectorNodeKey = "specialist_answers_collect" singleIntentAnswerNodeKey = "single_intent_answer" multiIntentSummarizeNodeKey = "multi_intents_summarize" defaultSummarizerPrompt = "summarize the answers from the specialists into a single answer." map2ListConverterNodeKey = "map_to_list" ) type state struct { msgs []*schema.Message isMultipleIntents bool } // NewMultiAgent creates a new host multi-agent system. // // IMPORTANT!! For models that don't output tool calls in the first streaming chunk (e.g. Claude) // the default StreamToolCallChecker may not work properly since it only checks the first chunk for tool calls. // In such cases, you need to implement a custom StreamToolCallChecker that can properly detect tool calls. 
func NewMultiAgent(ctx context.Context, config *MultiAgentConfig) (*MultiAgent, error) { if err := config.validate(); err != nil { return nil, err } hostKeyName := defaultHostNodeKey if config.HostNodeName != "" { hostKeyName = config.HostNodeName } var ( hostPrompt = config.Host.SystemPrompt name = config.Name toolCallChecker = config.StreamToolCallChecker ) if len(hostPrompt) == 0 { hostPrompt = defaultHostPrompt } if len(name) == 0 { name = "host multi agent" } if toolCallChecker == nil { toolCallChecker = firstChunkStreamToolCallChecker } g := compose.NewGraph[[]*schema.Message, *schema.Message]( compose.WithGenLocalState(func(context.Context) *state { return &state{} })) if err := g.AddPassthroughNode(specialistsAnswersCollectorNodeKey); err != nil { return nil, err } agentTools := make([]*schema.ToolInfo, 0, len(config.Specialists)) agentMap := make(map[string]bool, len(config.Specialists)+1) for i := range config.Specialists { specialist := config.Specialists[i] agentTools = append(agentTools, &schema.ToolInfo{ Name: specialist.Name, Desc: specialist.IntendedUse, ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ "reason": { Type: schema.String, Desc: "the reason to call this tool", }, }), }) if err := addSpecialistAgent(specialist, g); err != nil { return nil, err } agentMap[specialist.Name] = true } chatModel, err := agent.ChatModelWithTools(config.Host.ChatModel, config.Host.ToolCallingModel, agentTools) if err != nil { return nil, err } if err = addHostAgent(chatModel, hostPrompt, g, hostKeyName); err != nil { return nil, err } const convertorName = "msg2MsgList" if err = g.AddLambdaNode(convertorName, compose.ToList[*schema.Message](), compose.WithNodeName("converter")); err != nil { return nil, err } if err = addDirectAnswerBranch(convertorName, g, toolCallChecker); err != nil { return nil, err } if err = addMultiSpecialistsBranch(convertorName, agentMap, g); err != nil { return nil, err } if err = 
addSingleIntentAnswerNode(g); err != nil { return nil, err } if err = addMultiIntentsSummarizeNode(config.Summarizer, g); err != nil { return nil, err } if err = addAfterSpecialistsBranch(g); err != nil { return nil, err } compileOpts := []compose.GraphCompileOption{compose.WithNodeTriggerMode(compose.AnyPredecessor), compose.WithGraphName(name)} r, err := g.Compile(ctx, compileOpts...) if err != nil { return nil, err } return &MultiAgent{ runnable: r, graph: g, graphAddNodeOpts: []compose.GraphAddNodeOpt{compose.WithGraphCompileOptions(compileOpts...)}, }, nil } func addSpecialistAgent(specialist *Specialist, g *compose.Graph[[]*schema.Message, *schema.Message]) error { if specialist.Invokable != nil || specialist.Streamable != nil { lambda, err := compose.AnyLambda(specialist.Invokable, specialist.Streamable, nil, nil, compose.WithLambdaType("Specialist")) if err != nil { return err } preHandler := func(_ context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) { return state.msgs, nil // replace the tool call message with input msgs stored in state } if err := g.AddLambdaNode(specialist.Name, lambda, compose.WithStatePreHandler(preHandler), compose.WithNodeName(specialist.Name), compose.WithOutputKey(specialist.Name)); err != nil { return err } } else if specialist.ChatModel != nil { preHandler := func(_ context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) { if len(specialist.SystemPrompt) > 0 { return append([]*schema.Message{{ Role: schema.System, Content: specialist.SystemPrompt, }}, state.msgs...), nil } return state.msgs, nil // replace the tool call message with input msgs stored in state } if err := g.AddChatModelNode(specialist.Name, specialist.ChatModel, compose.WithStatePreHandler(preHandler), compose.WithNodeName(specialist.Name), compose.WithOutputKey(specialist.Name)); err != nil { return err } } return g.AddEdge(specialist.Name, specialistsAnswersCollectorNodeKey) } func addHostAgent(model 
model.BaseChatModel, prompt string, g *compose.Graph[[]*schema.Message, *schema.Message], hostNodeName string) error { preHandler := func(_ context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) { state.msgs = input if len(prompt) == 0 { return input, nil } return append([]*schema.Message{{ Role: schema.System, Content: prompt, }}, input...), nil } if err := g.AddChatModelNode(defaultHostNodeKey, model, compose.WithStatePreHandler(preHandler), compose.WithNodeName(hostNodeName)); err != nil { return err } return g.AddEdge(compose.START, defaultHostNodeKey) } func addDirectAnswerBranch(convertorName string, g *compose.Graph[[]*schema.Message, *schema.Message], toolCallChecker func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error)) error { // handles the case where the host agent returns a direct answer, instead of handling off to any specialist branch := compose.NewStreamGraphBranch(func(ctx context.Context, sr *schema.StreamReader[*schema.Message]) (endNode string, err error) { isToolCall, err := toolCallChecker(ctx, sr) if err != nil { return "", err } if isToolCall { return convertorName, nil } return compose.END, nil }, map[string]bool{convertorName: true, compose.END: true}) return g.AddBranch(defaultHostNodeKey, branch) } func addMultiSpecialistsBranch(convertorName string, agentMap map[string]bool, g *compose.Graph[[]*schema.Message, *schema.Message]) error { branch := compose.NewGraphMultiBranch(func(ctx context.Context, input []*schema.Message) (map[string]bool, error) { if len(input) != 1 { return nil, fmt.Errorf("host agent output %d messages, but expected 1", len(input)) } results := map[string]bool{} for _, toolCall := range input[0].ToolCalls { results[toolCall.Function.Name] = true } if len(results) > 1 { _ = compose.ProcessState(ctx, func(_ context.Context, state *state) error { state.isMultipleIntents = true return nil }) } return results, nil }, agentMap) return 
g.AddBranch(convertorName, branch) } func addSingleIntentAnswerNode(g *compose.Graph[[]*schema.Message, *schema.Message]) error { rc := func(ctx context.Context, input *schema.StreamReader[map[string]any]) (*schema.StreamReader[*schema.Message], error) { return schema.StreamReaderWithConvert(input, func(msgs map[string]any) (*schema.Message, error) { if len(msgs) != 1 { return nil, fmt.Errorf("host agent output %d messages, but expected 1", len(msgs)) } for _, msg := range msgs { return msg.(*schema.Message), nil } return nil, schema.ErrNoValue }), nil } _ = g.AddLambdaNode(singleIntentAnswerNodeKey, compose.TransformableLambda(rc)) return g.AddEdge(singleIntentAnswerNodeKey, compose.END) } func addAfterSpecialistsBranch(g *compose.Graph[[]*schema.Message, *schema.Message]) error { ab := func(ctx context.Context, _ *schema.StreamReader[map[string]any]) (string, error) { var isMultipleIntents bool _ = compose.ProcessState(ctx, func(_ context.Context, state *state) error { isMultipleIntents = state.isMultipleIntents return nil }) if !isMultipleIntents { return singleIntentAnswerNodeKey, nil } return map2ListConverterNodeKey, nil } b := compose.NewStreamGraphBranch(ab, map[string]bool{ singleIntentAnswerNodeKey: true, map2ListConverterNodeKey: true, }) return g.AddBranch(specialistsAnswersCollectorNodeKey, b) } func addMultiIntentsSummarizeNode(summarizer *Summarizer, g *compose.Graph[[]*schema.Message, *schema.Message]) error { map2list := func(ctx context.Context, input map[string]any) ([]*schema.Message, error) { var output []*schema.Message for k := range input { output = append(output, input[k].(*schema.Message)) } return output, nil } _ = g.AddLambdaNode(map2ListConverterNodeKey, compose.InvokableLambda(map2list)) if summarizer != nil { _ = g.AddChatModelNode(multiIntentSummarizeNodeKey, summarizer.ChatModel, compose.WithStatePreHandler(func(ctx context.Context, in []*schema.Message, state *state) ([]*schema.Message, error) { var ( out []*schema.Message 
systemPrompt = defaultSummarizerPrompt ) if summarizer.SystemPrompt != "" { systemPrompt = summarizer.SystemPrompt } out = append(out, &schema.Message{ Role: schema.System, Content: systemPrompt, }) out = append(out, state.msgs...) out = append(out, in...) return out, nil })) _ = g.AddEdge(map2ListConverterNodeKey, multiIntentSummarizeNodeKey) return g.AddEdge(multiIntentSummarizeNodeKey, compose.END) } s := func(ctx context.Context, input []*schema.Message) (*schema.Message, error) { output := &schema.Message{ Role: schema.Assistant, } for _, msg := range input { output.Content += msg.Content + "\n" } return output, nil } _ = g.AddLambdaNode(multiIntentSummarizeNodeKey, compose.InvokableLambda(s)) _ = g.AddEdge(map2ListConverterNodeKey, multiIntentSummarizeNodeKey) return g.AddEdge(multiIntentSummarizeNodeKey, compose.END) } ================================================ FILE: flow/agent/multiagent/host/compose_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package host import ( "context" "io" "sync" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/agent" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) func TestHostMultiAgent(t *testing.T) { ctrl := gomock.NewController(t) mockHostLLM := model.NewMockToolCallingChatModel(ctrl) mockSpecialistLLM1 := model.NewMockChatModel(ctrl) specialist1 := &Specialist{ ChatModel: mockSpecialistLLM1, SystemPrompt: "You are a helpful assistant.", AgentMeta: AgentMeta{ Name: "specialist 1", IntendedUse: "do stuff that works", }, } specialist2Msg1 := &schema.Message{ Role: schema.Assistant, Content: "specialist2", } specialist2Msg2 := &schema.Message{ Role: schema.Assistant, Content: " stream answer", } specialist2 := &Specialist{ Invokable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) { return &schema.Message{ Role: schema.Assistant, Content: "specialist2 invoke answer", }, nil }, Streamable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.StreamReader[*schema.Message], error) { return schema.StreamReaderFromArray([]*schema.Message{specialist2Msg1, specialist2Msg2}), nil }, AgentMeta: AgentMeta{ Name: "specialist 2", IntendedUse: "do stuff that works too", }, } ctx := context.Background() mockHostLLM.EXPECT().WithTools(gomock.Any()).Return(mockHostLLM, nil).AnyTimes() hostMA, err := NewMultiAgent(ctx, &MultiAgentConfig{ Host: Host{ ToolCallingModel: mockHostLLM, }, Specialists: []*Specialist{ specialist1, specialist2, }, }) assert.NoError(t, err) t.Run("generate direct answer from host", func(t *testing.T) { directAnswerMsg := &schema.Message{ Role: schema.Assistant, Content: "direct answer", } mockHostLLM.EXPECT().Generate(gomock.Any(), 
gomock.Any()).Return(directAnswerMsg, nil).Times(1) mockCallback := newMockAgentCallback(0) out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback)) assert.NoError(t, err) assert.Equal(t, "direct answer", out.Content) assert.Empty(t, mockCallback.infos) }) t.Run("stream direct answer from host", func(t *testing.T) { directAnswerMsg1 := &schema.Message{ Role: schema.Assistant, Content: "direct ", } directAnswerMsg2 := &schema.Message{ Role: schema.Assistant, Content: "answer", } sr, sw := schema.Pipe[*schema.Message](0) go func() { sw.Send(directAnswerMsg1, nil) sw.Send(directAnswerMsg2, nil) sw.Close() }() mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) mockCallback := newMockAgentCallback(0) outStream, err := hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback)) assert.NoError(t, err) assert.Empty(t, mockCallback.infos) var msgs []*schema.Message for { msg, err := outStream.Recv() if err == io.EOF { break } assert.NoError(t, err) msgs = append(msgs, msg) } outStream.Close() assert.Equal(t, directAnswerMsg1, msgs[0]) assert.Equal(t, directAnswerMsg2, msgs[1]) }) t.Run("generate hand off", func(t *testing.T) { handOffMsg := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), Function: schema.FunctionCall{ Name: specialist1.Name, Arguments: `{"reason": "specialist 1 is the best"}`, }, }, }, } specialistMsg := &schema.Message{ Role: schema.Assistant, Content: "specialist 1 answer", } mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1) mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1) mockCallback := newMockAgentCallback(1) out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback)) assert.NoError(t, err) assert.Equal(t, "specialist 1 answer", out.Content) mockCallback.wg.Wait() assert.Equal(t, []*HandOffInfo{ { ToAgentName: specialist1.Name, Argument: `{"reason": 
"specialist 1 is the best"}`, }, }, mockCallback.infos) handOffMsg.ToolCalls[0].Function.Name = specialist2.Name handOffMsg.ToolCalls[0].Function.Arguments = `{"reason": "specialist 2 is even better"}` mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1) mockCallback = newMockAgentCallback(1) out, err = hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback)) assert.NoError(t, err) assert.Equal(t, "specialist2 invoke answer", out.Content) mockCallback.wg.Wait() assert.Equal(t, []*HandOffInfo{ { ToAgentName: specialist2.Name, Argument: `{"reason": "specialist 2 is even better"}`, }, }, mockCallback.infos) }) t.Run("stream hand off to chat model", func(t *testing.T) { handOffMsg1 := &schema.Message{ Role: schema.Assistant, Content: "need to call function", } handOffMsg2 := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), }, }, } handOffMsg3 := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), Function: schema.FunctionCall{}, }, }, } handOffMsg4 := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), Function: schema.FunctionCall{ Name: specialist1.Name, Arguments: `{"reason": "specialist 1 is the best"}`, }, }, }, } sr, sw := schema.Pipe[*schema.Message](0) go func() { sw.Send(handOffMsg1, nil) sw.Send(handOffMsg2, nil) sw.Send(handOffMsg3, nil) sw.Send(handOffMsg4, nil) sw.Close() }() specialistMsg1 := &schema.Message{ Role: schema.Assistant, Content: "specialist ", } specialistMsg2 := &schema.Message{ Role: schema.Assistant, Content: "1 answer", } sr1, sw1 := schema.Pipe[*schema.Message](0) go func() { sw1.Send(specialistMsg1, nil) sw1.Send(specialistMsg2, nil) sw1.Close() }() streamToolCallChecker := func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error) { defer modelOutput.Close() for { msg, err := modelOutput.Recv() if err != nil { if err == 
io.EOF { return false, nil } return false, err } if len(msg.ToolCalls) == 0 { continue } if len(msg.ToolCalls) > 0 { return true, nil } } } hostMA, err = NewMultiAgent(ctx, &MultiAgentConfig{ Host: Host{ ToolCallingModel: mockHostLLM, }, Specialists: []*Specialist{ specialist1, specialist2, }, StreamToolCallChecker: streamToolCallChecker, }) assert.NoError(t, err) mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) mockSpecialistLLM1.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr1, nil).Times(1) mockCallback := newMockAgentCallback(1) outStream, err := hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback)) assert.NoError(t, err) var msgs []*schema.Message for { msg, err := outStream.Recv() if err == io.EOF { break } assert.NoError(t, err) msgs = append(msgs, msg) } outStream.Close() assert.Equal(t, specialistMsg1, msgs[0]) assert.Equal(t, specialistMsg2, msgs[1]) mockCallback.wg.Wait() assert.Equal(t, []*HandOffInfo{ { ToAgentName: specialist1.Name, Argument: `{"reason": "specialist 1 is the best"}`, }, }, mockCallback.infos) handOffMsg4.ToolCalls[0].Function.Name = specialist2.Name handOffMsg4.ToolCalls[0].Function.Arguments = `{"reason": "specialist 2 is even better"}` sr, sw = schema.Pipe[*schema.Message](0) go func() { sw.Send(handOffMsg1, nil) sw.Send(handOffMsg2, nil) sw.Send(handOffMsg3, nil) sw.Send(handOffMsg4, nil) sw.Close() }() mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) mockCallback = newMockAgentCallback(1) outStream, err = hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback)) assert.NoError(t, err) msgs = nil for { msg, err := outStream.Recv() if err == io.EOF { break } assert.NoError(t, err) msgs = append(msgs, msg) } outStream.Close() assert.Equal(t, specialist2Msg1, msgs[0]) assert.Equal(t, specialist2Msg2, msgs[1]) mockCallback.wg.Wait() assert.Equal(t, []*HandOffInfo{ { ToAgentName: specialist2.Name, Argument: `{"reason": "specialist 2 is even better"}`, }, }, 
mockCallback.infos) }) t.Run("multi-agent within graph", func(t *testing.T) { handOffMsg := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), Function: schema.FunctionCall{ Name: specialist1.Name, Arguments: `{"reason": "specialist 1 is the best"}`, }, }, }, } specialistMsg := &schema.Message{ Role: schema.Assistant, Content: "Beijing", } mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1) mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1) mockCallback := newMockAgentCallback(1) hostMA, err := NewMultiAgent(ctx, &MultiAgentConfig{ Host: Host{ ToolCallingModel: mockHostLLM, }, Specialists: []*Specialist{ specialist1, specialist2, }, }) assert.NoError(t, err) maGraph, opts := hostMA.ExportGraph() fullGraph, err := compose.NewChain[map[string]any, *schema.Message](). AppendChatTemplate(prompt.FromMessages(schema.FString, schema.UserMessage("what's the capital city of {country_name}"))). AppendGraph(maGraph, append(opts, compose.WithNodeKey("host_ma_node"))...). 
Compile(ctx) assert.NoError(t, err) out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, compose.WithCallbacks(ConvertCallbackHandlers(mockCallback)).DesignateNodeWithPath(compose.NewNodePath("host_ma_node", hostMA.HostNodeKey()))) assert.NoError(t, err) assert.Equal(t, "Beijing", out.Content) mockCallback.wg.Wait() assert.Equal(t, []*HandOffInfo{ { ToAgentName: specialist1.Name, Argument: `{"reason": "specialist 1 is the best"}`, }, }, mockCallback.infos) }) t.Run("multiple intents", func(t *testing.T) { handOffMsg1 := &schema.Message{ Role: schema.Assistant, Content: "need to call function", } handOffMsg2 := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), }, }, } handOffMsg3 := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), Function: schema.FunctionCall{}, }, }, } handOffMsg4 := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), Function: schema.FunctionCall{ Name: specialist1.Name, Arguments: `{"reason": "specialist 1 is good"}`, }, }, { Index: generic.PtrOf(1), Function: schema.FunctionCall{ Name: specialist2.Name, Arguments: `{"reason": "specialist 2`, }, }, }, } handOffMsg5 := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(1), Function: schema.FunctionCall{ Arguments: ` is also good"}`, }, }, }, } sr := schema.StreamReaderFromArray([]*schema.Message{ handOffMsg1, handOffMsg2, handOffMsg3, handOffMsg4, handOffMsg5, }) specialist1Msg1 := &schema.Message{ Role: schema.Assistant, Content: "specialist ", } specialist1Msg2 := &schema.Message{ Role: schema.Assistant, Content: "1 answer", } sr1 := schema.StreamReaderFromArray([]*schema.Message{ specialist1Msg1, specialist1Msg2, }) streamToolCallChecker := func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error) { defer modelOutput.Close() for { msg, err := 
modelOutput.Recv() if err != nil { if err == io.EOF { return false, nil } return false, err } if len(msg.ToolCalls) == 0 { continue } if len(msg.ToolCalls) > 0 { return true, nil } } } hostMA, err = NewMultiAgent(ctx, &MultiAgentConfig{ Host: Host{ ToolCallingModel: mockHostLLM, }, Specialists: []*Specialist{ specialist1, specialist2, }, StreamToolCallChecker: streamToolCallChecker, }) assert.NoError(t, err) mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) mockSpecialistLLM1.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr1, nil).Times(1) mockCallback := newMockAgentCallback(2) outStream, err := hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback)) assert.NoError(t, err) var msgs []*schema.Message for { msg, err := outStream.Recv() if err == io.EOF { break } assert.NoError(t, err) msgs = append(msgs, msg) } outStream.Close() msg, err := schema.ConcatMessages(msgs) assert.NoError(t, err) if msg.Content != "specialist2 stream answer\nspecialist 1 answer\n" && msg.Content != "specialist 1 answer\nspecialist2 stream answer\n" { t.Errorf("Unexpected message content: %s", msg.Content) } mockCallback.wg.Wait() assert.Equal(t, []*HandOffInfo{ { ToAgentName: specialist1.Name, Argument: `{"reason": "specialist 1 is good"}`, }, { ToAgentName: specialist2.Name, Argument: `{"reason": "specialist 2 is also good"}`, }, }, mockCallback.infos) }) t.Run("summarize multiple intents", func(t *testing.T) { handOffMsg := &schema.Message{ Role: schema.Assistant, ToolCalls: []schema.ToolCall{ { Index: generic.PtrOf(0), Function: schema.FunctionCall{ Name: specialist1.Name, Arguments: `{"reason": "specialist 1 is good"}`, }, }, { Index: generic.PtrOf(1), Function: schema.FunctionCall{ Name: specialist2.Name, Arguments: `{"reason": "specialist 2 is also good"}`, }, }, }, } sr := schema.StreamReaderFromArray([]*schema.Message{ handOffMsg, }) specialist1Msg1 := &schema.Message{ Role: schema.Assistant, Content: "specialist 1 answer", } sr1 := 
schema.StreamReaderFromArray([]*schema.Message{ specialist1Msg1, }) const summaryContent = "summarized answer" sr2 := schema.StreamReaderFromArray([]*schema.Message{ { Role: schema.Assistant, Content: summaryContent, }, }) mockSumChatModel := model.NewMockChatModel(ctrl) hostMA, err = NewMultiAgent(ctx, &MultiAgentConfig{ Host: Host{ ToolCallingModel: mockHostLLM, }, Specialists: []*Specialist{ specialist1, specialist2, }, Summarizer: &Summarizer{ ChatModel: mockSumChatModel, }, }) assert.NoError(t, err) mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1) mockSpecialistLLM1.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr1, nil).Times(1) mockSumChatModel.EXPECT().Stream(gomock.Any(), gomock.Cond(func(x any) bool { return assert.Equal(t, defaultSummarizerPrompt, func() string { if input := x.([]*schema.Message); len(input) > 0 { return input[0].Content } return "" }()) })).Return(sr2, nil).Times(1) outStream, err := hostMA.Stream(ctx, nil) assert.NoError(t, err) var msgs []*schema.Message for { msg, err := outStream.Recv() if err == io.EOF { break } assert.NoError(t, err) msgs = append(msgs, msg) } outStream.Close() msg, err := schema.ConcatMessages(msgs) assert.NoError(t, err) if msg.Content != summaryContent { t.Errorf("Unexpected message content: %s", msg.Content) } }, ) } type mockAgentCallback struct { infos []*HandOffInfo wg sync.WaitGroup } func (m *mockAgentCallback) OnHandOff(ctx context.Context, info *HandOffInfo) context.Context { m.infos = append(m.infos, info) m.wg.Done() return ctx } func newMockAgentCallback(expects int) *mockAgentCallback { m := &mockAgentCallback{ infos: make([]*HandOffInfo, 0), wg: sync.WaitGroup{}, } m.wg.Add(expects) return m } ================================================ FILE: flow/agent/multiagent/host/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file 
except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package host ================================================ FILE: flow/agent/multiagent/host/options.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package host import "github.com/cloudwego/eino/flow/agent" type options struct { agentCallbacks []MultiAgentCallback } // WithAgentCallbacks registers callbacks to be invoked by the host multi-agent. func WithAgentCallbacks(agentCallbacks ...MultiAgentCallback) agent.AgentOption { return agent.WrapImplSpecificOptFn(func(opts *options) { opts.agentCallbacks = append(opts.agentCallbacks, agentCallbacks...) }) } ================================================ FILE: flow/agent/multiagent/host/types.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package host implements the host pattern for multi-agent system. package host import ( "context" "errors" "fmt" "io" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/agent" "github.com/cloudwego/eino/schema" ) // MultiAgent is a host multi-agent system. // A host agent is responsible for deciding which specialist to 'hand off' the task to. // One or more specialist agents are responsible for completing the task. type MultiAgent struct { runnable compose.Runnable[[]*schema.Message, *schema.Message] graph *compose.Graph[[]*schema.Message, *schema.Message] graphAddNodeOpts []compose.GraphAddNodeOpt } // Generate runs the multi-agent synchronously and returns the final message. func (ma *MultiAgent) Generate(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) { composeOptions := agent.GetComposeOptions(opts...) handler := convertCallbacks(opts...) if handler != nil { composeOptions = append(composeOptions, compose.WithCallbacks(handler).DesignateNode(ma.HostNodeKey())) } return ma.runnable.Invoke(ctx, input, composeOptions...) } // Stream runs the multi-agent in streaming mode and returns a message stream. func (ma *MultiAgent) Stream(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.StreamReader[*schema.Message], error) { composeOptions := agent.GetComposeOptions(opts...) handler := convertCallbacks(opts...) 
if handler != nil { composeOptions = append(composeOptions, compose.WithCallbacks(handler).DesignateNode(ma.HostNodeKey())) } return ma.runnable.Stream(ctx, input, composeOptions...) } // ExportGraph exports the underlying graph from MultiAgent, along with the []compose.GraphAddNodeOpt to be used when adding this graph to another graph. func (ma *MultiAgent) ExportGraph() (compose.AnyGraph, []compose.GraphAddNodeOpt) { return ma.graph, ma.graphAddNodeOpts } // HostNodeKey returns the graph node key used for the host agent. func (ma *MultiAgent) HostNodeKey() string { return defaultHostNodeKey } // MultiAgentConfig is the config for host multi-agent system. type MultiAgentConfig struct { Host Host Specialists []*Specialist Name string // the name of the host multi-agent HostNodeName string // the name of the host node in the graph, default is "host" // StreamToolCallChecker is a function to determine whether the model's streaming output contains tool calls. // Different models have different ways of outputting tool calls in streaming mode: // - Some models (like OpenAI) output tool calls directly // - Others (like Claude) output text first, then tool calls // This handler allows custom logic to check for tool calls in the stream. // It should return: // - true if the output contains tool calls and agent should continue processing // - false if no tool calls and agent should stop // Note: This field only needs to be configured when using streaming mode // Note: The handler MUST close the modelOutput stream before returning // Optional. By default, it checks if the first chunk contains tool calls. // Note: The default implementation does not work well with Claude, which typically outputs tool calls after text content. // Note: If your ChatModel doesn't output tool calls first, you can try adding prompts to constrain the model from generating extra text during the tool call. 
StreamToolCallChecker func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error) // Summarizer is the summarizer agent that will summarize the outputs of all the chosen specialist agents. // Only when the Host agent picks multiple Specialist will this be called. // If you do not provide a summarizer, a default summarizer that simply concatenates all the output messages into one message will be used. // Note: the default summarizer do not support streaming. Summarizer *Summarizer } func (conf *MultiAgentConfig) validate() error { if conf == nil { return errors.New("host multi agent config is nil") } if conf.Host.ChatModel == nil && conf.Host.ToolCallingModel == nil { return errors.New("host multi agent host ChatModel is nil") } if len(conf.Specialists) == 0 { return errors.New("host multi agent specialists are empty") } for _, s := range conf.Specialists { if s.ChatModel == nil && s.Invokable == nil && s.Streamable == nil { return fmt.Errorf("specialist %s has no chat model or Invokable or Streamable", s.Name) } if err := s.AgentMeta.validate(); err != nil { return err } } return nil } // AgentMeta is the meta information of an agent within a multi-agent system. type AgentMeta struct { Name string // the name of the agent, should be unique within multi-agent system IntendedUse string // the intended use-case of the agent, used as the reason for the multi-agent system to hand over control to this agent } func (am AgentMeta) validate() error { if len(am.Name) == 0 { return errors.New("agent meta name is empty") } if len(am.IntendedUse) == 0 { return errors.New("agent meta intended use is empty") } return nil } // Host is the host agent within a multi-agent system. // Currently, it can only be a model.ChatModel. type Host struct { ToolCallingModel model.ToolCallingChatModel // Deprecated: ChatModel is deprecated, please use ToolCallingModel instead. // This field will be removed in a future release. 
ChatModel model.ChatModel SystemPrompt string } // Specialist is a specialist agent within a host multi-agent system. // It can be a model.ChatModel or any Invokable and/or Streamable, such as react.Agent. // ChatModel and (Invokable / Streamable) are mutually exclusive, only one should be provided. // notice: SystemPrompt only effects when ChatModel has been set. // If Invokable is provided but not Streamable, then the Specialist will be 'compose.InvokableLambda'. // If Streamable is provided but not Invokable, then the Specialist will be 'compose.StreamableLambda'. // if Both Invokable and Streamable is provided, then the Specialist will be 'compose.AnyLambda'. type Specialist struct { AgentMeta ChatModel model.BaseChatModel SystemPrompt string Invokable compose.Invoke[[]*schema.Message, *schema.Message, agent.AgentOption] Streamable compose.Stream[[]*schema.Message, *schema.Message, agent.AgentOption] } // Summarizer defines a lightweight agent used to summarize // conversations or tool outputs using a chat model and prompt. type Summarizer struct { ChatModel model.BaseChatModel SystemPrompt string } func firstChunkStreamToolCallChecker(_ context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) { defer sr.Close() for { msg, err := sr.Recv() if err == io.EOF { return false, nil } if err != nil { return false, err } if len(msg.ToolCalls) > 0 { return true, nil } if len(msg.Content) == 0 { // skip empty chunks at the front continue } return false, nil } } ================================================ FILE: flow/agent/react/callback.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package react provides helpers to build callback handlers for React agents. package react import ( "github.com/cloudwego/eino/callbacks" template "github.com/cloudwego/eino/utils/callbacks" ) // BuildAgentCallback builds a callback handler for agent. // e.g. // // callback := BuildAgentCallback(modelHandler, toolHandler) // agent, err := react.NewAgent(ctx, &AgentConfig{}) // agent.Generate(ctx, input, agent.WithComposeOptions(compose.WithCallbacks(callback))) func BuildAgentCallback(modelHandler *template.ModelCallbackHandler, toolHandler *template.ToolCallbackHandler) callbacks.Handler { return template.NewHandlerHelper().ChatModel(modelHandler).Tool(toolHandler).Handler() } ================================================ FILE: flow/agent/react/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package react ================================================ FILE: flow/agent/react/option.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package react import ( "context" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/agent" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/schema" ub "github.com/cloudwego/eino/utils/callbacks" ) // WithToolOptions returns an agent option that specifies tool.Option for the tools in agent. func WithToolOptions(opts ...tool.Option) agent.AgentOption { return agent.WithComposeOptions(compose.WithToolsNodeOption(compose.WithToolOption(opts...))) } // WithChatModelOptions returns an agent option that specifies model.Option for the chat model in agent. func WithChatModelOptions(opts ...model.Option) agent.AgentOption { return agent.WithComposeOptions(compose.WithChatModelOption(opts...)) } // WithToolList returns an agent option that specifies compose.ToolsNodeOption for ToolsNode in agent. // If you also need to pass ToolInfo to the chat model, use WithTools instead. // Deprecated: This changes tool list for ToolsNode ONLY. 
func WithToolList(tools ...tool.BaseTool) agent.AgentOption { return agent.WithComposeOptions(compose.WithToolsNodeOption(compose.WithToolList(tools...))) } // WithTools is a convenience function that configures a React agent with a list of tools. // It performs two essential operations: // 1. Extracts tool information for the chat model to understand available tools // 2. Registers the actual tool implementations for execution // // Parameters: // - ctx: The context for the operation, used when calling Info() on each tool // - tools: A variadic list of tools that must implement either InvokableTool or StreamableTool interfaces // // Returns: // - []agent.AgentOption: A slice containing exactly 2 agent options: // - Option 1: Configures the chat model with tool schemas via model.WithTools(toolInfos) // - Option 2: Registers the tool implementations via compose.WithToolList(tools...) // - error: Returns an error if any tool's Info() method fails // // Usage Example: // // ctx := context.Background() // agentOptions, err := WithTools(ctx, myTool1, myTool2, myTool3) // if err != nil { // return fmt.Errorf("failed to configure tools: %w", err) // } // // agent, err := react.NewAgent(ctx, &react.AgentConfig{ // ToolCallingModel: myModel, // // other config... // }) // if err != nil { // return fmt.Errorf("failed to create agent: %w", err) // } // // // Use the tool options with Generate or Stream methods // msg, err := agent.Generate(ctx, messages, agentOptions...) // // or // stream, err := agent.Stream(ctx, messages, agentOptions...) 
// // Comparison with Related Functions: // - WithToolList: Only registers tool implementations, doesn't configure the chat model // - WithTools: Comprehensive setup that handles both chat model configuration and tool registration // // Notes: // - The function always returns exactly 2 options when successful // - Both returned options should be applied to the agent for proper tool functionality func WithTools(ctx context.Context, tools ...tool.BaseTool) ([]agent.AgentOption, error) { toolInfos := make([]*schema.ToolInfo, 0, len(tools)) for _, tl := range tools { info, err := tl.Info(ctx) if err != nil { return nil, err } toolInfos = append(toolInfos, info) } opts := make([]agent.AgentOption, 2) opts[0] = agent.WithComposeOptions(compose.WithChatModelOption(model.WithTools(toolInfos))) opts[1] = agent.WithComposeOptions(compose.WithToolsNodeOption(compose.WithToolList(tools...))) return opts, nil } // Iterator provides a lightweight FIFO stream of values and errors // produced during agent execution. type Iterator[T any] struct { ch *internal.UnboundedChan[item[T]] } // Next retrieves the next value from the iterator. // It returns the zero value and false when the stream is exhausted. func (iter *Iterator[T]) Next() (T, bool, error) { ch := iter.ch if ch == nil { var zero T return zero, false, nil } i, ok := ch.Receive() if !ok { var zero T return zero, false, nil } return i.v, true, i.err } // MessageFuture exposes asynchronous accessors for messages produced // by Generate and Stream calls. type MessageFuture interface { // GetMessages returns an iterator for retrieving messages generated during "agent.Generate" calls. GetMessages() *Iterator[*schema.Message] // GetMessageStreams returns an iterator for retrieving streaming messages generated during "agent.Stream" calls. GetMessageStreams() *Iterator[*schema.StreamReader[*schema.Message]] } // WithMessageFuture returns an agent option and a MessageFuture interface instance. 
// The option configures the agent to collect messages generated during execution, // while the MessageFuture interface allows users to asynchronously retrieve these messages. func WithMessageFuture() (agent.AgentOption, MessageFuture) { h := &cbHandler{started: make(chan struct{})} cmHandler := &ub.ModelCallbackHandler{ OnEnd: h.onChatModelEnd, OnEndWithStreamOutput: h.onChatModelEndWithStreamOutput, } createToolResultSender := func() toolResultSender { return func(toolName, callID, result string) { msg := schema.ToolMessage(result, callID, schema.WithToolName(toolName)) h.sendMessage(msg) } } createStreamToolResultSender := func() streamToolResultSender { return func(toolName, callID string, resultStream *schema.StreamReader[string]) { cvt := func(in string) (*schema.Message, error) { return schema.ToolMessage(in, callID, schema.WithToolName(toolName)), nil } msgStream := schema.StreamReaderWithConvert(resultStream, cvt) h.sendMessageStream(msgStream) } } createEnhancedToolResultSender := func() enhancedToolResultSender { return func(toolName, callID string, result *schema.ToolResult) { var err error msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) msg.UserInputMultiContent, err = result.ToMessageInputParts() if err != nil { return } h.sendMessage(msg) } } createEnhancedStreamToolResultSender := func() enhancedStreamToolResultSender { return func(toolName, callID string, resultStream *schema.StreamReader[*schema.ToolResult]) { cvt := func(result *schema.ToolResult) (*schema.Message, error) { var err error msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) msg.UserInputMultiContent, err = result.ToMessageInputParts() if err != nil { return nil, err } return msg, nil } msgStream := schema.StreamReaderWithConvert(resultStream, cvt) h.sendMessageStream(msgStream) } } graphHandler := callbacks.NewHandlerBuilder(). 
OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { h.onGraphStart(ctx, info, input) return setToolResultSendersToCtx(ctx, &toolResultSenders{ sender: createToolResultSender(), streamSender: createStreamToolResultSender(), enhancedResultSender: createEnhancedToolResultSender(), enhancedStreamToolResultSender: createEnhancedStreamToolResultSender(), }) }). OnStartWithStreamInputFn(func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { h.onGraphStartWithStreamInput(ctx, info, input) return setToolResultSendersToCtx(ctx, &toolResultSenders{ sender: createToolResultSender(), streamSender: createStreamToolResultSender(), enhancedResultSender: createEnhancedToolResultSender(), enhancedStreamToolResultSender: createEnhancedStreamToolResultSender(), }) }). OnEndFn(h.onGraphEnd). OnEndWithStreamOutputFn(h.onGraphEndWithStreamOutput). OnErrorFn(h.onGraphError).Build() cb := ub.NewHandlerHelper().ChatModel(cmHandler).Graph(graphHandler).Handler() option := agent.WithComposeOptions(compose.WithCallbacks(cb)) return option, h } type item[T any] struct { v T err error } type cbHandler struct { msgs *internal.UnboundedChan[item[*schema.Message]] sMsgs *internal.UnboundedChan[item[*schema.StreamReader[*schema.Message]]] started chan struct{} } func (h *cbHandler) GetMessages() *Iterator[*schema.Message] { <-h.started return &Iterator[*schema.Message]{ch: h.msgs} } func (h *cbHandler) GetMessageStreams() *Iterator[*schema.StreamReader[*schema.Message]] { <-h.started return &Iterator[*schema.StreamReader[*schema.Message]]{ch: h.sMsgs} } func (h *cbHandler) onChatModelEnd(ctx context.Context, _ *callbacks.RunInfo, input *model.CallbackOutput) context.Context { h.sendMessage(input.Message) return ctx } func (h *cbHandler) onChatModelEndWithStreamOutput(ctx context.Context, _ *callbacks.RunInfo, input *schema.StreamReader[*model.CallbackOutput]) context.Context { 
c := func(output *model.CallbackOutput) (*schema.Message, error) { return output.Message, nil } s := schema.StreamReaderWithConvert(input, c) h.sendMessageStream(s) return ctx } func (h *cbHandler) onGraphError(ctx context.Context, _ *callbacks.RunInfo, err error) context.Context { if h.msgs != nil { h.msgs.Send(item[*schema.Message]{err: err}) } else { h.sMsgs.Send(item[*schema.StreamReader[*schema.Message]]{err: err}) } return ctx } func (h *cbHandler) onGraphEnd(ctx context.Context, _ *callbacks.RunInfo, _ callbacks.CallbackOutput) context.Context { h.msgs.Close() return ctx } func (h *cbHandler) onGraphEndWithStreamOutput(ctx context.Context, _ *callbacks.RunInfo, _ *schema.StreamReader[callbacks.CallbackOutput]) context.Context { h.sMsgs.Close() return ctx } func (h *cbHandler) onGraphStart(ctx context.Context, _ *callbacks.RunInfo, _ callbacks.CallbackInput) context.Context { h.msgs = internal.NewUnboundedChan[item[*schema.Message]]() close(h.started) return ctx } func (h *cbHandler) onGraphStartWithStreamInput(ctx context.Context, _ *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { input.Close() h.sMsgs = internal.NewUnboundedChan[item[*schema.StreamReader[*schema.Message]]]() close(h.started) return ctx } func (h *cbHandler) sendMessage(msg *schema.Message) { if h.msgs != nil { h.msgs.Send(item[*schema.Message]{v: msg}) } else { sMsg := schema.StreamReaderFromArray([]*schema.Message{msg}) h.sMsgs.Send(item[*schema.StreamReader[*schema.Message]]{v: sMsg}) } } func (h *cbHandler) sendMessageStream(sMsg *schema.StreamReader[*schema.Message]) { if h.sMsgs != nil { h.sMsgs.Send(item[*schema.StreamReader[*schema.Message]]{v: sMsg}) } else { // concat msg, err := schema.ConcatMessageStream(sMsg) if err != nil { h.msgs.Send(item[*schema.Message]{err: err}) } else { h.msgs.Send(item[*schema.Message]{v: msg}) } } } ================================================ FILE: flow/agent/react/option_test.go 
================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package react import ( "context" "sync" "testing" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" ) func TestWithMessageFuture(t *testing.T) { ctx := context.Background() // Test with tool calls t.Run("test generate with tool calls", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) fakeTool := &fakeToolGreetForTest{} info, err := fakeTool.Info(ctx) assert.NoError(t, err) // Mock model response with tool call cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"name": "test user"}`, }, }, }), nil). Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("final response", nil), nil). 
Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create agent with MessageFuture option, future := WithMessageFuture() a, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, MaxStep: 3, }) assert.Nil(t, err) // Generate response response, err := a.Generate(ctx, []*schema.Message{ schema.UserMessage("use the greet tool"), }, option) assert.Nil(t, err) assert.Equal(t, "final response", response.Content) sIter := future.GetMessageStreams() // Should be no messages _, hasNext, err := sIter.Next() assert.Nil(t, err) assert.False(t, hasNext) iter := future.GetMessages() // First message should be the assistant message for tool calling msg1, hasNext, err := iter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.Equal(t, schema.Assistant, msg1.Role) assert.Equal(t, 1, len(msg1.ToolCalls)) // Second message should be the tool response msg2, hasNext, err := iter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.Equal(t, schema.Tool, msg2.Role) // Third message should be the final response msg3, hasNext, err := iter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.Equal(t, "final response", msg3.Content) // Should be no more messages _, hasNext, err = iter.Next() assert.Nil(t, err) assert.False(t, hasNext) }) // Test with streaming tool calls t.Run("test generate with streaming tool calls", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) fakeTool := &fakeStreamToolGreetForTest{} info, err := fakeTool.Info(ctx) assert.NoError(t, err) // Mock model response with tool call cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"name": "test user"}`, }, }, }), nil). Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.AssistantMessage("final response", nil), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create agent with MessageFuture option, future := WithMessageFuture() a, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, MaxStep: 3, }) assert.Nil(t, err) // Generate response response, err := a.Generate(ctx, []*schema.Message{ schema.UserMessage("use the greet tool"), }, option) assert.Nil(t, err) assert.Equal(t, "final response", response.Content) // Get messages from future iter := future.GetMessages() // First message should be the assistant message for tool calling msg1, hasNext, err := iter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.Equal(t, schema.Assistant, msg1.Role) assert.Equal(t, 1, len(msg1.ToolCalls)) // Second message should be the tool response msg2, hasNext, err := iter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.Equal(t, schema.Tool, msg2.Role) // Third message should be the final response msg3, hasNext, err := iter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.Equal(t, "final response", msg3.Content) // Should be no more messages _, hasNext, err = iter.Next() assert.Nil(t, err) assert.False(t, hasNext) }) // Test with non-streaming tool but using agent's Stream interface t.Run("test stream with tool calls", func(t *testing.T) { ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) fakeTool := &fakeToolGreetForTest{} info, err := fakeTool.Info(ctx) assert.NoError(t, err) // Mock model response with tool call cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"name": "test user"}`, }, }, })}), nil). Times(1) cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("final response", nil)}), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create agent with MessageFuture option, future := WithMessageFuture() a, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, MaxStep: 3, }) assert.Nil(t, err) // Use Stream interface stream, err := a.Stream(ctx, []*schema.Message{ schema.UserMessage("use the greet tool"), }, option) assert.Nil(t, err) // Collect all chunks from stream finalResponse, err := schema.ConcatMessageStream(stream) assert.Nil(t, err) assert.Equal(t, "final response", finalResponse.Content) iter := future.GetMessages() // Should be no messages _, hasNext, err := iter.Next() assert.Nil(t, err) assert.False(t, hasNext) // Get message streams from future sIter := future.GetMessageStreams() // First message should be the assistant message for tool calling stream1, hasNext, err := sIter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.NotNil(t, stream1) msg1, err := schema.ConcatMessageStream(stream1) assert.Nil(t, err) assert.Equal(t, schema.Assistant, msg1.Role) assert.Equal(t, 1, len(msg1.ToolCalls)) // Second message should be the tool response stream2, hasNext, err := sIter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.NotNil(t, stream2) msg2, err := schema.ConcatMessageStream(stream2) assert.Nil(t, err) assert.Equal(t, schema.Tool, msg2.Role) // Third message should be the final response stream3, hasNext, err := sIter.Next() assert.Nil(t, err) assert.True(t, hasNext) assert.NotNil(t, stream3) msg3, err := schema.ConcatMessageStream(stream3) assert.Nil(t, err) assert.Equal(t, "final response", msg3.Content) // Should be no more messages _, hasNext, err = sIter.Next() assert.Nil(t, err) assert.False(t, hasNext) }) t.Run("test stream with streaming tool calls and with concurrent goroutines", func(t *testing.T) { 
ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) fakeTool := &fakeStreamToolGreetForTest{} info, err := fakeTool.Info(ctx) assert.NoError(t, err) // Mock model response with tool call cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"name": "test user"}`, }, }, })}), nil). Times(1) cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("final response", nil)}), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() // Create agent with MessageFuture option, future := WithMessageFuture() a, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, MaxStep: 3, }) assert.Nil(t, err) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() // Get message streams from future sIter := future.GetMessageStreams() // First message should be the assistant message for tool calling stream1, hasNext, err_ := sIter.Next() assert.Nil(t, err_) assert.True(t, hasNext) assert.NotNil(t, stream1) msg1, err_ := schema.ConcatMessageStream(stream1) assert.Nil(t, err_) assert.Equal(t, schema.Assistant, msg1.Role) assert.Equal(t, 1, len(msg1.ToolCalls)) // Second message should be the tool response stream2, hasNext, err_ := sIter.Next() assert.Nil(t, err_) assert.True(t, hasNext) assert.NotNil(t, stream2) msg2, err_ := schema.ConcatMessageStream(stream2) assert.Nil(t, err_) assert.Equal(t, schema.Tool, msg2.Role) // Third message should be the final response stream3, hasNext, err_ := sIter.Next() assert.Nil(t, err_) assert.True(t, hasNext) assert.NotNil(t, stream3) msg3, err_ := schema.ConcatMessageStream(stream3) assert.Nil(t, err_) assert.Equal(t, "final response", msg3.Content) // 
Should be no more messages _, hasNext, err_ = sIter.Next() assert.Nil(t, err_) assert.False(t, hasNext) }() // Use Stream interface stream, err := a.Stream(ctx, []*schema.Message{ schema.UserMessage("use the greet tool"), }, option) assert.Nil(t, err) // Collect all chunks from stream finalResponse, err := schema.ConcatMessageStream(stream) assert.Nil(t, err) assert.Equal(t, "final response", finalResponse.Content) wg.Wait() }) } func TestWithToolOptions(t *testing.T) { type dummyOpt struct{ val string } opt := tool.WrapImplSpecificOptFn(func(o *dummyOpt) { o.val = "mock" }) agentOpt := WithToolOptions(opt) assert.NotNil(t, agentOpt) // The returned value should be an agent.AgentOption (function) assert.IsType(t, agentOpt, agentOpt) } func TestWithChatModelOptions(t *testing.T) { opt := model.WithModel("mock-model") agentOpt := WithChatModelOptions(opt) assert.NotNil(t, agentOpt) assert.IsType(t, agentOpt, agentOpt) } // dummyBaseTool is a minimal implementation of tool.BaseTool for testing. type dummyBaseTool struct{} func (d *dummyBaseTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "dummy"}, nil } func (d *dummyBaseTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { return "dummy-response", nil } type assertTool struct { toolOptVal string receivedToolOpt bool } type toolOpt struct{ val string } func (a *assertTool) Info(ctx context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{Name: "assert_tool"}, nil } func (a *assertTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { opt := tool.GetImplSpecificOptions(&toolOpt{}, opts...) 
if opt.val == a.toolOptVal { a.receivedToolOpt = true } return "tool-response", nil } func TestAgentWithAllOptions(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) // Prepare a tool that asserts it receives the tool option toolOptVal := "tool-opt-value" to := tool.WrapImplSpecificOptFn(func(o *toolOpt) { o.val = toolOptVal }) at := &assertTool{toolOptVal: toolOptVal} // Prepare a mock chat model that asserts it receives the model option cm := mockModel.NewMockToolCallingChatModel(ctrl) modelOpt := model.WithModel("test-model") modelOptReceived := false times := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(_ context.Context, _ []*schema.Message, opts ...model.Option) (*schema.Message, error) { times++ if times == 1 { for _, o := range opts { opt := model.GetCommonOptions(&model.Options{}, o) if opt.Model != nil && *opt.Model == "test-model" { modelOptReceived = true } } info, _ := at.Info(ctx) return schema.AssistantMessage("hello max", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: "", }, }, }), nil } return schema.AssistantMessage("ok", nil), nil }, ).AnyTimes() cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() agentOpt := WithToolOptions(to) agentOpt2 := WithChatModelOptions(modelOpt) agentOpt3, err := WithTools(context.Background(), at) assert.NoError(t, err) a, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{&dummyBaseTool{}}, }, MaxStep: 20, }) assert.NoError(t, err) _, err = a.Generate(ctx, []*schema.Message{ schema.UserMessage("call the tool"), }, agentOpt, agentOpt2, agentOpt3[0], agentOpt3[1]) assert.NoError(t, err) assert.True(t, modelOptReceived, "model option should be received by chat model") assert.True(t, at.receivedToolOpt, "tool option should be received by tool") } type simpleToolForMiddlewareTest struct { name string result string } func (s 
*simpleToolForMiddlewareTest) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: s.name, Desc: "simple tool for middleware test", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "input": { Desc: "input", Required: true, Type: schema.String, }, }), }, nil } func (s *simpleToolForMiddlewareTest) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { return s.result, nil } func (s *simpleToolForMiddlewareTest) StreamableRun(_ context.Context, _ string, _ ...tool.Option) (*schema.StreamReader[string], error) { return schema.StreamReaderFromArray([]string{s.result}), nil } func TestMessageFuture_ToolResultMiddleware_EmitsFinalResult(t *testing.T) { originalResult := "original_result" modifiedResult := "modified_by_middleware" resultModifyingMiddleware := compose.ToolMiddleware{ Invokable: func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { output, err := next(ctx, input) if err != nil { return nil, err } output.Result = modifiedResult return output, nil } }, Streamable: func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { output, err := next(ctx, input) if err != nil { return nil, err } output.Result = schema.StreamReaderFromArray([]string{modifiedResult}) return output, nil } }, } t.Run("Invoke", func(t *testing.T) { ctx := context.Background() testTool := &simpleToolForMiddlewareTest{name: "test_tool", result: originalResult} ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) info, err := testTool.Info(ctx) assert.NoError(t, err) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.AssistantMessage("", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"input": "test"}`, }, }, }), nil). Times(1) cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.AssistantMessage("final response", nil), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() option, future := WithMessageFuture() a, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{testTool}, ToolCallMiddlewares: []compose.ToolMiddleware{resultModifyingMiddleware}, }, MaxStep: 3, }) assert.NoError(t, err) response, err := a.Generate(ctx, []*schema.Message{ schema.UserMessage("call the tool"), }, option) assert.NoError(t, err) assert.Equal(t, "final response", response.Content) iter := future.GetMessages() var allMsgs []*schema.Message for { msg, hasNext, err := iter.Next() if err != nil || !hasNext { break } allMsgs = append(allMsgs, msg) } assert.GreaterOrEqual(t, len(allMsgs), 3, "should have at least 3 messages") if len(allMsgs) >= 3 { assert.Equal(t, schema.Assistant, allMsgs[0].Role) assert.Equal(t, 1, len(allMsgs[0].ToolCalls)) assert.Equal(t, schema.Tool, allMsgs[1].Role) assert.Equal(t, modifiedResult, allMsgs[1].Content, "MessageFuture should receive the middleware-modified tool result") assert.NotEqual(t, originalResult, allMsgs[1].Content, "MessageFuture should NOT receive the original tool result") assert.Equal(t, "final response", allMsgs[2].Content) } }) t.Run("Stream", func(t *testing.T) { ctx := context.Background() testTool := &simpleToolForMiddlewareTest{name: "test_tool_stream", result: originalResult} ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) info, err := testTool.Info(ctx) assert.NoError(t, err) cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). 
Return(schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("", []schema.ToolCall{ { ID: "tool-call-1", Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"input": "test"}`, }, }, }), }), nil). Times(1) cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). Return(schema.StreamReaderFromArray([]*schema.Message{ schema.AssistantMessage("final response", nil), }), nil). Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() option, future := WithMessageFuture() a, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{testTool}, ToolCallMiddlewares: []compose.ToolMiddleware{resultModifyingMiddleware}, }, MaxStep: 3, }) assert.NoError(t, err) response, err := a.Stream(ctx, []*schema.Message{ schema.UserMessage("call the tool"), }, option) assert.NoError(t, err) var msgs []*schema.Message for { msg, err := response.Recv() if err != nil { break } msgs = append(msgs, msg) } finalMsg, err := schema.ConcatMessages(msgs) assert.NoError(t, err) assert.Equal(t, "final response", finalMsg.Content) iter := future.GetMessageStreams() var allMsgs []*schema.Message for { msgStream, hasNext, err := iter.Next() if err != nil || !hasNext { break } var streamMsgs []*schema.Message for { msg, err := msgStream.Recv() if err != nil { break } streamMsgs = append(streamMsgs, msg) } if len(streamMsgs) > 0 { concated, err := schema.ConcatMessages(streamMsgs) if err == nil { allMsgs = append(allMsgs, concated) } } } assert.GreaterOrEqual(t, len(allMsgs), 3, "should have at least 3 messages") if len(allMsgs) >= 3 { assert.Equal(t, schema.Assistant, allMsgs[0].Role) assert.Equal(t, 1, len(allMsgs[0].ToolCalls)) assert.Equal(t, schema.Tool, allMsgs[1].Role) assert.Equal(t, modifiedResult, allMsgs[1].Content, "MessageFuture should receive the middleware-modified tool result") assert.NotEqual(t, originalResult, allMsgs[1].Content, "MessageFuture should NOT receive the original tool 
result") assert.Equal(t, "final response", allMsgs[2].Content) } }) } ================================================ FILE: flow/agent/react/react.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package react import ( "context" "io" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/agent" "github.com/cloudwego/eino/schema" ) type toolResultSender func(toolName, callID, result string) type enhancedToolResultSender func(toolName, callID string, result *schema.ToolResult) type streamToolResultSender func(toolName, callID string, resultStream *schema.StreamReader[string]) type enhancedStreamToolResultSender func(toolName, callID string, resultStream *schema.StreamReader[*schema.ToolResult]) type toolResultSenders struct { sender toolResultSender streamSender streamToolResultSender enhancedResultSender enhancedToolResultSender enhancedStreamToolResultSender enhancedStreamToolResultSender } type toolResultSenderCtxKey struct{} func setToolResultSendersToCtx(ctx context.Context, senders *toolResultSenders) context.Context { return context.WithValue(ctx, toolResultSenderCtxKey{}, senders) } func getToolResultSendersFromCtx(ctx context.Context) *toolResultSenders { v := ctx.Value(toolResultSenderCtxKey{}) if v == nil { return nil } return v.(*toolResultSenders) } type state struct { Messages []*schema.Message ReturnDirectlyToolCallID 
string } func init() { schema.RegisterName[*state]("_eino_react_state") } func newToolResultCollectorMiddleware() compose.ToolMiddleware { return compose.ToolMiddleware{ Invokable: func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { senders := getToolResultSendersFromCtx(ctx) output, err := next(ctx, input) if err != nil { return nil, err } if senders != nil && senders.sender != nil { senders.sender(input.Name, input.CallID, output.Result) } return output, nil } }, Streamable: func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { senders := getToolResultSendersFromCtx(ctx) output, err := next(ctx, input) if err != nil { return nil, err } if senders != nil && senders.streamSender != nil { streams := output.Result.Copy(2) senders.streamSender(input.Name, input.CallID, streams[0]) output.Result = streams[1] } return output, nil } }, EnhancedInvokable: func(next compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { senders := getToolResultSendersFromCtx(ctx) output, err := next(ctx, input) if err != nil { return nil, err } if senders != nil && senders.enhancedResultSender != nil { senders.enhancedResultSender(input.Name, input.CallID, output.Result) } return output, nil } }, EnhancedStreamable: func(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { senders := getToolResultSendersFromCtx(ctx) output, err := next(ctx, input) if err != nil { return nil, err } if senders != nil && senders.enhancedStreamToolResultSender != nil { streams := output.Result.Copy(2) 
senders.enhancedStreamToolResultSender(input.Name, input.CallID, streams[0]) output.Result = streams[1] } return output, nil } }, } } const ( nodeKeyTools = "tools" nodeKeyModel = "chat" ) // MessageModifier modify the input messages before the model is called. type MessageModifier func(ctx context.Context, input []*schema.Message) []*schema.Message // AgentConfig is the config for ReAct agent. type AgentConfig struct { // ToolCallingModel is the chat model to be used for handling user messages with tool calling capability. // This is the recommended model field to use. ToolCallingModel model.ToolCallingChatModel // Deprecated: Use ToolCallingModel instead. Model model.ChatModel // ToolsConfig is the config for tools node. ToolsConfig compose.ToolsNodeConfig // MessageModifier. // modify the input messages before the model is called, it's useful when you want to add some system prompt or other messages. MessageModifier MessageModifier // MessageRewriter modifies message in the state, before the ChatModel is called. // It takes the messages stored accumulated in state, modify them, and put the modified version back into state. // Useful for compressing message history to fit the model context window, // or if you want to make changes to messages that take effect across multiple model calls. // NOTE: if both MessageModifier and MessageRewriter are set, MessageRewriter will be called before MessageModifier. MessageRewriter MessageModifier // MaxStep. // default 12 of steps in pregel (node num + 10). MaxStep int `json:"max_step"` // Tools that will make agent return directly when the tool is called. // When multiple tools are called and more than one tool is in the return directly list, only the first one will be returned. ToolReturnDirectly map[string]struct{} // StreamToolCallChecker is a function to determine whether the model's streaming output contains tool calls. 
// Different models have different ways of outputting tool calls in streaming mode: // - Some models (like OpenAI) output tool calls directly // - Others (like Claude) output text first, then tool calls // This handler allows custom logic to check for tool calls in the stream. // It should return: // - true if the output contains tool calls and agent should continue processing // - false if no tool calls and agent should stop // Note: This field only needs to be configured when using streaming mode // Note: The handler MUST close the modelOutput stream before returning // Optional. By default, it checks if the first chunk contains tool calls. // Note: The default implementation does not work well with Claude, which typically outputs tool calls after text content. // Note: If your ChatModel doesn't output tool calls first, you can try adding prompts to constrain the model from generating extra text during the tool call. StreamToolCallChecker func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error) // GraphName is the graph name of the ReAct Agent. // Optional. Default `ReActAgent`. GraphName string // ModelNodeName is the node name of the model node in the ReAct Agent graph. // Optional. Default `ChatModel`. ModelNodeName string // ToolsNodeName is the node name of the tools node in the ReAct Agent graph. // Optional. Default `Tools`. ToolsNodeName string } // NewPersonaModifier returns a MessageModifier that adds a persona message to the input. // example: // // persona := "You are an expert in golang." 
// config := AgentConfig{ // Model: model, // MessageModifier: NewPersonaModifier(persona), // } // agent, err := NewAgent(ctx, config) // if err != nil {return} // msg, err := agent.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "how to build agent with eino"}}) // if err != nil {return} // println(msg.Content) // // Deprecated: Prefer directly including the persona message in the // input when calling Generate or Stream to avoid extra copying. func NewPersonaModifier(persona string) MessageModifier { return func(ctx context.Context, input []*schema.Message) []*schema.Message { res := make([]*schema.Message, 0, len(input)+1) res = append(res, schema.SystemMessage(persona)) res = append(res, input...) return res } } func firstChunkStreamToolCallChecker(_ context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) { defer sr.Close() for { msg, err := sr.Recv() if err == io.EOF { return false, nil } if err != nil { return false, err } if len(msg.ToolCalls) > 0 { return true, nil } if len(msg.Content) == 0 { // skip empty chunks at the front continue } return false, nil } } // Default graph and node names for the ReAct agent. const ( GraphName = "ReActAgent" ModelNodeName = "ChatModel" ToolsNodeName = "Tools" ) // SetReturnDirectly is a helper function that can be called within a tool's execution. // It signals the ReAct agent to stop further processing and return the result of the current tool call directly. // This is useful when the tool's output is the final answer and no more steps are needed. // Note: If multiple tools call this function in the same step, only the last call will take effect. // This setting has a higher priority than the AgentConfig.ToolReturnDirectly. func SetReturnDirectly(ctx context.Context) error { return compose.ProcessState(ctx, func(ctx context.Context, s *state) error { s.ReturnDirectlyToolCallID = compose.GetToolCallID(ctx) return nil }) } // Agent is the ReAct agent. 
// ReAct agent is a simple agent that handles user messages with a chat model and tools. // ReAct will call the chat model, if the message contains tool calls, it will call the tools. // if the tool is configured to return directly, ReAct will return directly. // otherwise, ReAct will continue to call the chat model until the message contains no tool calls. // e.g. // // agent, err := ReAct.NewAgent(ctx, &react.AgentConfig{}) // if err != nil {...} // msg, err := agent.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "how to build agent with eino"}}) // if err != nil {...} // println(msg.Content) type Agent struct { runnable compose.Runnable[[]*schema.Message, *schema.Message] graph *compose.Graph[[]*schema.Message, *schema.Message] graphAddNodeOpts []compose.GraphAddNodeOpt } // NewAgent creates a ReAct agent that feeds tool response into next round of Chat Model generation. // // IMPORTANT!! For models that don't output tool calls in the first streaming chunk (e.g. Claude) // the default StreamToolCallChecker may not work properly since it only checks the first chunk for tool calls. // In such cases, you need to implement a custom StreamToolCallChecker that can properly detect tool calls. 
func NewAgent(ctx context.Context, config *AgentConfig) (_ *Agent, err error) { var ( chatModel model.BaseChatModel toolsNode *compose.ToolsNode toolInfos []*schema.ToolInfo toolCallChecker = config.StreamToolCallChecker messageModifier = config.MessageModifier ) graphName := GraphName if config.GraphName != "" { graphName = config.GraphName } modelNodeName := ModelNodeName if config.ModelNodeName != "" { modelNodeName = config.ModelNodeName } toolsNodeName := ToolsNodeName if config.ToolsNodeName != "" { toolsNodeName = config.ToolsNodeName } if toolCallChecker == nil { toolCallChecker = firstChunkStreamToolCallChecker } if toolInfos, err = genToolInfos(ctx, config.ToolsConfig); err != nil { return nil, err } if chatModel, err = agent.ChatModelWithTools(config.Model, config.ToolCallingModel, toolInfos); err != nil { return nil, err } config.ToolsConfig.ToolCallMiddlewares = append( []compose.ToolMiddleware{newToolResultCollectorMiddleware()}, config.ToolsConfig.ToolCallMiddlewares..., ) if toolsNode, err = compose.NewToolNode(ctx, &config.ToolsConfig); err != nil { return nil, err } graph := compose.NewGraph[[]*schema.Message, *schema.Message](compose.WithGenLocalState(func(ctx context.Context) *state { return &state{Messages: make([]*schema.Message, 0, config.MaxStep+1)} })) modelPreHandle := func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) { state.Messages = append(state.Messages, input...) 
if config.MessageRewriter != nil { state.Messages = config.MessageRewriter(ctx, state.Messages) } if messageModifier == nil { return state.Messages, nil } modifiedInput := make([]*schema.Message, len(state.Messages)) copy(modifiedInput, state.Messages) return messageModifier(ctx, modifiedInput), nil } if err = graph.AddChatModelNode(nodeKeyModel, chatModel, compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(modelNodeName)); err != nil { return nil, err } if err = graph.AddEdge(compose.START, nodeKeyModel); err != nil { return nil, err } toolsNodePreHandle := func(ctx context.Context, input *schema.Message, state *state) (*schema.Message, error) { if input == nil { return state.Messages[len(state.Messages)-1], nil // used for rerun interrupt resume } state.Messages = append(state.Messages, input) state.ReturnDirectlyToolCallID = getReturnDirectlyToolCallID(input, config.ToolReturnDirectly) return input, nil } if err = graph.AddToolsNode(nodeKeyTools, toolsNode, compose.WithStatePreHandler(toolsNodePreHandle), compose.WithNodeName(toolsNodeName)); err != nil { return nil, err } modelPostBranchCondition := func(ctx context.Context, sr *schema.StreamReader[*schema.Message]) (endNode string, err error) { if isToolCall, err := toolCallChecker(ctx, sr); err != nil { return "", err } else if isToolCall { return nodeKeyTools, nil } return compose.END, nil } if err = graph.AddBranch(nodeKeyModel, compose.NewStreamGraphBranch(modelPostBranchCondition, map[string]bool{nodeKeyTools: true, compose.END: true})); err != nil { return nil, err } if err = buildReturnDirectly(graph); err != nil { return nil, err } compileOpts := []compose.GraphCompileOption{compose.WithMaxRunSteps(config.MaxStep), compose.WithNodeTriggerMode(compose.AnyPredecessor), compose.WithGraphName(graphName)} runnable, err := graph.Compile(ctx, compileOpts...) 
if err != nil { return nil, err } return &Agent{ runnable: runnable, graph: graph, graphAddNodeOpts: []compose.GraphAddNodeOpt{compose.WithGraphCompileOptions(compileOpts...)}, }, nil } func buildReturnDirectly(graph *compose.Graph[[]*schema.Message, *schema.Message]) (err error) { directReturn := func(ctx context.Context, msgs *schema.StreamReader[[]*schema.Message]) (*schema.StreamReader[*schema.Message], error) { return schema.StreamReaderWithConvert(msgs, func(msgs []*schema.Message) (*schema.Message, error) { var msg *schema.Message err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error { for i := range msgs { if msgs[i] != nil && msgs[i].ToolCallID == state.ReturnDirectlyToolCallID { msg = msgs[i] return nil } } return nil }) if err != nil { return nil, err } if msg == nil { return nil, schema.ErrNoValue } return msg, nil }), nil } nodeKeyDirectReturn := "direct_return" if err = graph.AddLambdaNode(nodeKeyDirectReturn, compose.TransformableLambda(directReturn)); err != nil { return err } // this branch checks if the tool called should return directly. 
It either leads to END or back to ChatModel err = graph.AddBranch(nodeKeyTools, compose.NewStreamGraphBranch(func(ctx context.Context, msgsStream *schema.StreamReader[[]*schema.Message]) (endNode string, err error) { msgsStream.Close() err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error { if len(state.ReturnDirectlyToolCallID) > 0 { endNode = nodeKeyDirectReturn } else { endNode = nodeKeyModel } return nil }) if err != nil { return "", err } return endNode, nil }, map[string]bool{nodeKeyModel: true, nodeKeyDirectReturn: true})) if err != nil { return err } return graph.AddEdge(nodeKeyDirectReturn, compose.END) } func genToolInfos(ctx context.Context, config compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) { toolInfos := make([]*schema.ToolInfo, 0, len(config.Tools)) for _, t := range config.Tools { tl, err := t.Info(ctx) if err != nil { return nil, err } toolInfos = append(toolInfos, tl) } return toolInfos, nil } func getReturnDirectlyToolCallID(input *schema.Message, toolReturnDirectly map[string]struct{}) string { if len(toolReturnDirectly) == 0 { return "" } for _, toolCall := range input.ToolCalls { if _, ok := toolReturnDirectly[toolCall.Function.Name]; ok { return toolCall.ID } } return "" } // Generate generates a response from the agent. func (r *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) { return r.runnable.Invoke(ctx, input, agent.GetComposeOptions(opts...)...) } // Stream calls the agent and returns a stream response. func (r *Agent) Stream(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (output *schema.StreamReader[*schema.Message], err error) { return r.runnable.Stream(ctx, input, agent.GetComposeOptions(opts...)...) } // ExportGraph exports the underlying graph from Agent, along with the []compose.GraphAddNodeOpt to be used when adding this graph to another graph. 
func (r *Agent) ExportGraph() (compose.AnyGraph, []compose.GraphAddNodeOpt) { return r.graph, r.graphAddNodeOpts } ================================================ FILE: flow/agent/react/react_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package react import ( "context" "errors" "fmt" "io" "math/rand" "testing" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/agent" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" template "github.com/cloudwego/eino/utils/callbacks" ) func TestReact(t *testing.T) { ctx := context.Background() fakeTool := &fakeToolGreetForTest{ tarCount: 3, } info, err := fakeTool.Info(ctx) assert.NoError(t, err) ctrl := gomock.NewController(t) cm := mockModel.NewMockChatModel(ctrl) times := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { times++ if times <= 2 { info, _ := fakeTool.Info(ctx) return schema.AssistantMessage("hello max", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStr()), }, }, }), nil } return schema.AssistantMessage("bye", nil), nil }).AnyTimes() cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() a, err := NewAgent(ctx, &AgentConfig{ Model: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, MessageModifier: func(ctx context.Context, input []*schema.Message) []*schema.Message { assert.Equal(t, len(input), times*2+1) return input }, MaxStep: 40, }) assert.Nil(t, err) out, err := a.Generate(ctx, []*schema.Message{ { Role: schema.User, Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) assert.Nil(t, err) if out != nil { t.Log(out.Content) } // test return directly times = 0 a, err = NewAgent(ctx, &AgentConfig{ Model: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, MessageModifier: func(ctx context.Context, input []*schema.Message) []*schema.Message { assert.Equal(t, len(input), times*2+1) return input }, MaxStep: 40, ToolReturnDirectly: map[string]struct{}{info.Name: {}}, }) assert.Nil(t, err) out, err = a.Generate(ctx, []*schema.Message{ { Role: schema.User, Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! 
please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) assert.Nil(t, err) if out != nil { t.Log(out.Content) } } func TestReactWithMessageRewriterAndModifier(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) // This test simulates a single Generate call with a long history. // The MessageRewriter should shorten the history. // The MessageModifier should add a system prompt. cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { // Check messages passed to the model. // Expected: [system prompt, user: "message 2", assistant: "response 2"] assert.Len(t, input, 3) assert.Equal(t, schema.System, input[0].Role) assert.Equal(t, "system prompt", input[0].Content) assert.Equal(t, schema.User, input[1].Role) assert.Equal(t, "message 2", input[1].Content) assert.Equal(t, schema.Assistant, input[2].Role) assert.Equal(t, "response 2", input[2].Content) return schema.AssistantMessage("final response", nil), nil }).Times(1) cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() ra, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, MessageRewriter: func(ctx context.Context, messages []*schema.Message) []*schema.Message { // Keep only the last 2 messages if history is longer. assert.Len(t, messages, 4) // user1, assistant1, user2, assistant2 if len(messages) > 2 { return messages[len(messages)-2:] } return messages }, MessageModifier: func(ctx context.Context, messages []*schema.Message) []*schema.Message { // messages should be the result from rewriter assert.Len(t, messages, 2) // user2, assistant2 // Add a system prompt res := make([]*schema.Message, 0, len(messages)+1) res = append(res, schema.SystemMessage("system prompt")) res = append(res, messages...) 
return res }, }) assert.NoError(t, err) // Simulate a conversation history history := []*schema.Message{ schema.UserMessage("message 1"), schema.AssistantMessage("response 1", nil), schema.UserMessage("message 2"), schema.AssistantMessage("response 2", nil), } // Run the react agent finalMsg, err := ra.Generate(ctx, history) assert.NoError(t, err) assert.Equal(t, "final response", finalMsg.Content) } func TestReactStream(t *testing.T) { ctx := context.Background() fakeTool := &fakeToolGreetForTest{ tarCount: 20, } fakeStreamTool := &fakeStreamToolGreetForTest{ tarCount: 20, } ctrl := gomock.NewController(t) cm := mockModel.NewMockChatModel(ctrl) times := 0 cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) ( *schema.StreamReader[*schema.Message], error) { sr, sw := schema.Pipe[*schema.Message](1) defer sw.Close() info, _ := fakeTool.Info(ctx) streamInfo, _ := fakeStreamTool.Info(ctx) times++ if times <= 2 { sw.Send(schema.AssistantMessage("hello max", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "tool"}`, randStr()), }, }, }), nil) return sr, nil } else if times == 3 { sw.Send(schema.AssistantMessage("hello max", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: streamInfo.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "stream tool"}`, randStr()), }, }, }), nil) return sr, nil } else if times == 4 { // parallel tool call sw.Send(schema.AssistantMessage("hello max", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "tool"}`, randStr()), }, }, { ID: randStr(), Function: schema.FunctionCall{ Name: streamInfo.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "stream tool"}`, randStr()), }, }, }), nil) return sr, nil } 
sw.Send(schema.AssistantMessage("bye", nil), nil) return sr, nil }).AnyTimes() a, err := NewAgent(ctx, &AgentConfig{ Model: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool, fakeStreamTool}, }, MaxStep: 40, }) assert.Nil(t, err) out, err := a.Stream(ctx, []*schema.Message{ { Role: schema.User, Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) if err != nil { t.Fatal(err) } defer out.Close() msgs := make([]*schema.Message, 0) for { msg, err := out.Recv() if err != nil { if errors.Is(err, io.EOF) { break } t.Fatal(err) } msgs = append(msgs, msg) } assert.Equal(t, 1, len(msgs)) msg, err := schema.ConcatMessages(msgs) if err != nil { t.Fatal(err) } t.Log(msg.Content) info, err := fakeStreamTool.Info(ctx) assert.NoError(t, err) // test return directly a, err = NewAgent(ctx, &AgentConfig{ Model: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool, fakeStreamTool}, }, MaxStep: 40, ToolReturnDirectly: map[string]struct{}{info.Name: {}}, // one of the two tools is return directly }) assert.Nil(t, err) times = 0 out, err = a.Stream(ctx, []*schema.Message{ { Role: schema.User, Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! 
please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) if err != nil { t.Fatal(err) } defer out.Close() msgs = make([]*schema.Message, 0) for { msg, err := out.Recv() if err != nil { if errors.Is(err, io.EOF) { break } t.Fatal(err) } msgs = append(msgs, msg) } assert.Equal(t, 1, len(msgs)) msg, err = schema.ConcatMessages(msgs) if err != nil { t.Fatal(err) } t.Log(msg.Content) // return directly tool call within parallel tool calls out, err = a.Stream(ctx, []*schema.Message{ { Role: schema.User, Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) assert.NoError(t, err) defer out.Close() msgs = make([]*schema.Message, 0) for { msg, err := out.Recv() if err != nil { if errors.Is(err, io.EOF) { break } assert.NoError(t, err) } msgs = append(msgs, msg) } assert.Equal(t, 1, len(msgs)) msg, err = schema.ConcatMessages(msgs) assert.NoError(t, err) t.Log("parallel tool call with return directly: ", msg.Content) } func TestReactWithModifier(t *testing.T) { ctx := context.Background() fakeTool := &fakeToolGreetForTest{} ctrl := gomock.NewController(t) cm := mockModel.NewMockChatModel(ctrl) times := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { times++ if times <= 2 { info, _ := fakeTool.Info(ctx) return schema.AssistantMessage("hello max", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStr()), }, }, }), nil } return schema.AssistantMessage("bye", nil), nil }).AnyTimes() cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() a, err := NewAgent(ctx, &AgentConfig{ Model: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool}, }, MessageModifier: func(ctx context.Context, input []*schema.Message) []*schema.Message { res := make([]*schema.Message, 0, len(input)+1) res = append(res, schema.SystemMessage("you are a helpful assistant")) res = append(res, input...) return res }, MaxStep: 40, }) assert.Nil(t, err) out, err := a.Generate(ctx, []*schema.Message{ { Role: schema.User, Content: "hello", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) if err != nil { t.Fatal(err) } if out != nil { t.Log(out.Content) } } func TestAgentInGraph(t *testing.T) { t.Run("agent generate in chain", func(t *testing.T) { ctx := context.Background() fakeTool := &fakeToolGreetForTest{} ctrl := gomock.NewController(t) cm := mockModel.NewMockChatModel(ctrl) times := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { times += 1 if times <= 2 { info, _ := fakeTool.Info(ctx) return schema.AssistantMessage("hello max", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStr()), }, }, }), nil } return schema.AssistantMessage("bye", nil), nil }).Times(3) cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() a, err := NewAgent(ctx, &AgentConfig{ Model: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{fakeTool, &fakeStreamToolGreetForTest{}}, }, MaxStep: 40, }) assert.Nil(t, err) chain := compose.NewChain[[]*schema.Message, string]() agentLambda, err := compose.AnyLambda(a.Generate, a.Stream, nil, nil) assert.Nil(t, err) chain. AppendLambda(agentLambda). AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (string, error) { t.Log("got agent response: ", input.Content) return input.Content, nil })) r, err := chain.Compile(ctx) assert.Nil(t, err) res, err := r.Invoke(ctx, []*schema.Message{{Role: schema.User, Content: "hello"}}, compose.WithCallbacks(callbackForTest)) assert.Nil(t, err) t.Log(res) }) t.Run("agent stream in chain", func(t *testing.T) { fakeStreamTool := &fakeStreamToolGreetForTest{} ctx := context.Background() ctrl := gomock.NewController(t) cm := mockModel.NewMockChatModel(ctrl) times := 0 cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) ( *schema.StreamReader[*schema.Message], error) { sr, sw := schema.Pipe[*schema.Message](1) defer sw.Close() times += 1 if times <= 2 { info, _ := fakeStreamTool.Info(ctx) sw.Send(schema.AssistantMessage("hello max", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStr()), }, }, }), nil) return sr, nil } sw.Send(schema.AssistantMessage("bye", nil), nil) return sr, nil }).Times(3) cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() a, err := NewAgent(ctx, &AgentConfig{ Model: cm, ToolsConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{&fakeToolGreetForTest{}, fakeStreamTool}, }, MaxStep: 40, }) assert.Nil(t, err) chain := compose.NewChain[[]*schema.Message, string]() agentGraph, opts := a.ExportGraph() assert.Nil(t, err) chain. AppendGraph(agentGraph, opts...). AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (string, error) { t.Log("got agent response: ", input.Content) return input.Content, nil })) r, err := chain.Compile(ctx) assert.Nil(t, err) outStream, err := r.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hello"}}, compose.WithCallbacks(callbackForTest)) if err != nil { t.Fatal(err) } defer outStream.Close() msg := "" for { msgItem, err := outStream.Recv() if err != nil { if errors.Is(err, io.EOF) { break } t.Fatal(err) } msg += msgItem } t.Log(msg) }) } func TestWithTools(t *testing.T) { ctx := context.Background() fakeTool := &fakeToolGreetForTest{ tarCount: 2, } fakeStreamTool := &fakeStreamToolGreetForTest{ tarCount: 2, } ctrl := gomock.NewController(t) cm := mockModel.NewMockToolCallingChatModel(ctrl) times := 0 cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { times++ if times <= 1 { info, _ := fakeTool.Info(ctx) return schema.AssistantMessage("calling tool", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"name": "test"}`, }, }, }), nil } return schema.AssistantMessage("done", nil), nil }).AnyTimes() cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) ( *schema.StreamReader[*schema.Message], error) { sr, sw := schema.Pipe[*schema.Message](1) defer sw.Close() times++ if times <= 2 { info, _ := fakeStreamTool.Info(ctx) sw.Send(schema.AssistantMessage("calling stream tool", []schema.ToolCall{ { ID: randStr(), Function: schema.FunctionCall{ Name: info.Name, Arguments: `{"name": "test"}`, }, }, }), nil) return sr, nil } sw.Send(schema.AssistantMessage("stream done", nil), nil) return sr, nil }).AnyTimes() // Test WithTools function toolOptions, err := WithTools(ctx, fakeTool, fakeStreamTool) assert.NoError(t, err) assert.Len(t, toolOptions, 2, "WithTools should return exactly 2 options") // Create agent without tools in config a, err := NewAgent(ctx, &AgentConfig{ ToolCallingModel: cm, MaxStep: 10, }) assert.NoError(t, err) // Test Generate with WithTools options times = 0 msg, err := a.Generate(ctx, []*schema.Message{ schema.UserMessage("test generate with tools"), }, toolOptions...) assert.NoError(t, err) assert.Equal(t, "done", msg.Content) // Test Stream with WithTools options times = 0 stream, err := a.Stream(ctx, []*schema.Message{ schema.UserMessage("test stream with tools"), }, toolOptions...) 
assert.NoError(t, err) defer stream.Close() msgs := make([]*schema.Message, 0) for { msg, err := stream.Recv() if err != nil { if errors.Is(err, io.EOF) { break } assert.NoError(t, err) } msgs = append(msgs, msg) } assert.Len(t, msgs, 1) concatMsg, err := schema.ConcatMessages(msgs) assert.NoError(t, err) assert.Equal(t, "stream done", concatMsg.Content) // Test error case - tool Info() returns error errorTool := &errorToolForTest{} _, err = WithTools(ctx, errorTool) assert.Error(t, err) assert.Contains(t, err.Error(), "info error") } // Helper tool for testing error cases type errorToolForTest struct{} func (t *errorToolForTest) Info(_ context.Context) (*schema.ToolInfo, error) { return nil, errors.New("info error") } func (t *errorToolForTest) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { return "", nil } type fakeStreamToolGreetForTest struct { tarCount int curCount int } func (t *fakeStreamToolGreetForTest) StreamableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) ( *schema.StreamReader[string], error) { p := &fakeToolInput{} err := sonic.UnmarshalString(argumentsInJSON, p) if err != nil { return nil, err } if t.curCount >= t.tarCount { s := schema.StreamReaderFromArray([]string{`{"say": "bye"}`}) return s, nil } t.curCount++ s := schema.StreamReaderFromArray([]string{fmt.Sprintf(`{"say": "hello %v"}`, p.Name)}) return s, nil } type fakeToolGreetForTest struct { tarCount int curCount int } func (t *fakeToolGreetForTest) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "greet", Desc: "greet with name", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Desc: "user name who to greet", Required: true, Type: schema.String, }, }), }, nil } func (t *fakeStreamToolGreetForTest) Info(_ context.Context) (*schema.ToolInfo, error) { return &schema.ToolInfo{ Name: "greet in stream", Desc: "greet with name in stream", ParamsOneOf: 
schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Desc: "user name who to greet", Required: true, Type: schema.String, }, }), }, nil } func (t *fakeToolGreetForTest) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { p := &fakeToolInput{} err := sonic.UnmarshalString(argumentsInJSON, p) if err != nil { return "", err } if t.curCount >= t.tarCount { return `{"say": "bye"}`, nil } t.curCount++ return fmt.Sprintf(`{"say": "hello %v"}`, p.Name), nil } type fakeToolInput struct { Name string `json:"name"` } func randStr() string { seeds := []rune("this is a seed") b := make([]rune, 8) for i := range b { b[i] = seeds[rand.Intn(len(seeds))] } return string(b) } var callbackForTest = BuildAgentCallback(&template.ModelCallbackHandler{}, &template.ToolCallbackHandler{}) ================================================ FILE: flow/agent/utils.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package agent import ( "errors" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" ) // ChatModelWithTools returns a chat model configured with tool schemas. // If a ToolCallingChatModel is provided, it is used directly (and optionally // configured with tools). Otherwise, a plain ChatModel is bound with tools. 
func ChatModelWithTools(cm model.ChatModel, toolCallingModel model.ToolCallingChatModel, toolInfos []*schema.ToolInfo) ( model.BaseChatModel, error) { if toolCallingModel != nil { if len(toolInfos) == 0 { return toolCallingModel, nil } return toolCallingModel.WithTools(toolInfos) } if cm != nil { if len(toolInfos) == 0 { return cm, nil } err := cm.BindTools(toolInfos) if err != nil { return nil, err } return cm, nil } return nil, errors.New("no chat model provided") } ================================================ FILE: flow/indexer/parent/parent.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package parent provides an indexer that assigns stable IDs to sub-documents // and preserves relationships to their original parent document. package parent import ( "context" "fmt" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" ) // Config configures the parent indexer that assigns IDs to sub-documents. type Config struct { // Indexer is the underlying indexer implementation that handles the actual document indexing. // For example: a vector database indexer like Milvus, or a full-text search indexer like Elasticsearch. Indexer indexer.Indexer // Transformer processes documents before indexing, typically splitting them into smaller chunks. 
// Each sub-document generated by the transformer must retain its parent document's ID. // For example: if a document with ID "doc_1" is split into 3 chunks, all chunks will initially // have ID "doc_1". These IDs will later be modified by the SubIDGenerator. // // Example transformations: // - A text splitter that breaks down large documents into paragraphs // - A code splitter that separates code files into functions Transformer document.Transformer // ParentIDKey specifies the metadata key used to store the original document's ID in each sub-document. // For example: if ParentIDKey is "parent_id", each sub-document will have metadata like: // {"parent_id": "original_doc_123"} ParentIDKey string // SubIDGenerator generates unique IDs for sub-documents based on their parent document ID. // For example: if parent ID is "doc_1" and we need 3 sub-document IDs, it might generate: // ["doc_1_chunk_1", "doc_1_chunk_2", "doc_1_chunk_3"] // // Parameters: // - ctx: context for the operation // - parentID: the ID of the parent document // - num: number of sub-document IDs needed // Returns: // - []string: slice of generated sub-document IDs // - error: any error encountered during ID generation SubIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error) } // NewIndexer creates a new parent indexer that handles document splitting and sub-document management. 
// // Parameters: // - ctx: context for the operation // - config: configuration for the parent indexer // // Example usage: // // indexer, err := NewIndexer(ctx, &Config{ // Indexer: milvusIndexer, // Transformer: textSplitter, // ParentIDKey: "source_doc_id", // SubIDGenerator: func(ctx context.Context, parentID string, num int) ([]string, error) { // ids := make([]string, num) // for i := 0; i < num; i++ { // ids[i] = fmt.Sprintf("%s_chunk_%d", parentID, i+1) // } // return ids, nil // }, // }) // // Returns: // - indexer.Indexer: the created parent indexer // - error: any error encountered during creation func NewIndexer(ctx context.Context, config *Config) (indexer.Indexer, error) { if config.Indexer == nil { return nil, fmt.Errorf("indexer is empty") } if config.Transformer == nil { return nil, fmt.Errorf("transformer is empty") } if config.SubIDGenerator == nil { return nil, fmt.Errorf("sub id generator is empty") } return &parentIndexer{ indexer: config.Indexer, transformer: config.Transformer, parentIDKey: config.ParentIDKey, subIDGenerator: config.SubIDGenerator, }, nil } type parentIndexer struct { indexer indexer.Indexer transformer document.Transformer parentIDKey string subIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error) } func (p *parentIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { subDocs, err := p.transformer.Transform(ctx, docs) if err != nil { return nil, fmt.Errorf("transform docs fail: %w", err) } if len(subDocs) == 0 { return nil, fmt.Errorf("doc transformer returned no documents") } currentID := subDocs[0].ID startIdx := 0 for i, subDoc := range subDocs { if subDoc.MetaData == nil { subDoc.MetaData = make(map[string]any) } subDoc.MetaData[p.parentIDKey] = subDoc.ID if subDoc.ID == currentID { continue } // generate new doc id subIDs, err_ := p.subIDGenerator(ctx, subDocs[startIdx].ID, i-startIdx) if err_ != nil { return nil, err_ } if len(subIDs) != 
i-startIdx { return nil, fmt.Errorf("generated sub IDs' num is unexpected") } for j := startIdx; j < i; j++ { subDocs[j].ID = subIDs[j-startIdx] } startIdx = i currentID = subDoc.ID } // generate new doc id subIDs, err := p.subIDGenerator(ctx, subDocs[startIdx].ID, len(subDocs)-startIdx) if err != nil { return nil, err } if len(subIDs) != len(subDocs)-startIdx { return nil, fmt.Errorf("generated sub IDs' num is unexpected") } for j := startIdx; j < len(subDocs); j++ { subDocs[j].ID = subIDs[j-startIdx] } return p.indexer.Store(ctx, subDocs, opts...) } ================================================ FILE: flow/indexer/parent/parent_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package parent import ( "context" "fmt" "reflect" "strconv" "strings" "testing" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" ) type testIndexer struct{} func (t *testIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) { ret := make([]string, len(docs)) for i, d := range docs { ret[i] = d.ID if !strings.HasPrefix(d.ID, d.MetaData["parent"].(string)) { return nil, fmt.Errorf("invalid parent key") } } return ret, nil } type testTransformer struct { } func (t *testTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) { var ret []*schema.Document for _, d := range src { ret = append(ret, &schema.Document{ ID: d.ID, Content: d.Content[:len(d.Content)/2], MetaData: deepCopyMap(d.MetaData), }, &schema.Document{ ID: d.ID, Content: d.Content[len(d.Content)/2:], MetaData: deepCopyMap(d.MetaData), }) } return ret, nil } func TestParentIndexer(t *testing.T) { tests := []struct { name string config *Config input []*schema.Document want []string }{ { name: "success", config: &Config{ Indexer: &testIndexer{}, Transformer: &testTransformer{}, ParentIDKey: "parent", SubIDGenerator: func(ctx context.Context, parentID string, num int) ([]string, error) { ret := make([]string, num) for i := range ret { ret[i] = parentID + strconv.Itoa(i) } return ret, nil }, }, input: []*schema.Document{{ ID: "id", Content: "1234567890", MetaData: map[string]interface{}{}, }, { ID: "ID", Content: "0987654321", MetaData: map[string]interface{}{}, }}, want: []string{"id0", "id1", "ID0", "ID1"}, }, } ctx := context.Background() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { index, err := NewIndexer(ctx, tt.config) if err != nil { t.Fatal(err) } ret, err := index.Store(ctx, tt.input) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(ret, tt.want) { 
t.Errorf("NewHeaderSplitter() got = %v, want %v", ret, tt.want) } }) } } func deepCopyMap(in map[string]interface{}) map[string]interface{} { out := make(map[string]interface{}) for k, v := range in { out[k] = v } return out } ================================================ FILE: flow/retriever/multiquery/multi_query.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package multiquery implements a query-rewriting retriever that expands // user queries into multiple variants to improve recall. package multiquery import ( "context" "fmt" "strings" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/retriever/utils" "github.com/cloudwego/eino/schema" ) const ( defaultRewritePrompt = `You are an helpful assistant. Your role is to create three different versions of the user query to retrieve relevant documents from store. Your goal is to improve the performance of similarity search by generating text from different perspectives based on the user query. Only provide the generated queries and separate them by newlines. 
user query: {{query}}` defaultQueryVariable = "query" defaultMaxQueriesNum = 5 ) var deduplicateFusion = func(ctx context.Context, docs [][]*schema.Document) ([]*schema.Document, error) { m := map[string]bool{} var ret []*schema.Document for i := range docs { for j := range docs[i] { if _, ok := m[docs[i][j].ID]; !ok { m[docs[i][j].ID] = true ret = append(ret, docs[i][j]) } } } return ret, nil } // NewRetriever creates a multi-query retriever. // multi-query retriever is useful when you want to retrieve documents from multiple retrievers with different queries. // e.g. // // multiRetriever := multiquery.NewRetriever(ctx, &multiquery.Config{}) // docs, err := multiRetriever.Retrieve(ctx, "how to build agent with eino") // if err != nil { // ... // } // println(docs) func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) { var err error // config validate if config.OrigRetriever == nil { return nil, fmt.Errorf("OrigRetriever is required") } if config.RewriteHandler == nil && config.RewriteLLM == nil { return nil, fmt.Errorf("at least one of RewriteHandler and RewriteLLM must not be empty") } // construct rewrite chain rewriteChain := compose.NewChain[string, []string]() if config.RewriteHandler != nil { rewriteChain.AppendLambda(compose.InvokableLambda(config.RewriteHandler), compose.WithNodeName("CustomQueryRewriter")) } else { tpl := config.RewriteTemplate variable := config.QueryVar parser := config.LLMOutputParser if tpl == nil { tpl = prompt.FromMessages(schema.Jinja2, schema.UserMessage(defaultRewritePrompt)) variable = defaultQueryVariable } if parser == nil { parser = func(ctx context.Context, message *schema.Message) ([]string, error) { return strings.Split(message.Content, "\n"), nil } } rewriteChain. AppendLambda(compose.InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { return map[string]any{variable: input}, nil }), compose.WithNodeName("Converter")). AppendChatTemplate(tpl). 
AppendChatModel(config.RewriteLLM). AppendLambda(compose.InvokableLambda(parser), compose.WithNodeName("OutputParser")) } rewriteRunner, err := rewriteChain.Compile(ctx, compose.WithGraphName("QueryRewrite")) if err != nil { return nil, err } maxQueriesNum := config.MaxQueriesNum if maxQueriesNum == 0 { maxQueriesNum = defaultMaxQueriesNum } fusionFunc := config.FusionFunc if fusionFunc == nil { fusionFunc = deduplicateFusion } return &multiQueryRetriever{ queryRunner: rewriteRunner, maxQueriesNum: maxQueriesNum, origRetriever: config.OrigRetriever, fusionFunc: fusionFunc, }, nil } // Config is the config for multi-query retriever. type Config struct { // Rewrite // 1. set the following fields to use llm to generate multi queries // a. chat model, required RewriteLLM model.ChatModel // b. prompt llm to generate multi queries, we provide default template so you can leave this field blank RewriteTemplate prompt.ChatTemplate // c. origin query variable of your custom template, it can be empty if you use default template QueryVar string // d. parser llm output to queries, split content using "\n" by default LLMOutputParser func(context.Context, *schema.Message) ([]string, error) // 2. set RewriteHandler to provide custom query generation logic, possibly without a ChatModel. 
If this field is set, it takes precedence over other configurations above RewriteHandler func(ctx context.Context, query string) ([]string, error) // limit max queries num that Rewrite generates, and excess queries will be truncated, 5 by default MaxQueriesNum int // Origin Retriever OrigRetriever retriever.Retriever // fusion docs recalled from multi retrievers, remove dup based on document id by default FusionFunc func(ctx context.Context, docs [][]*schema.Document) ([]*schema.Document, error) } type multiQueryRetriever struct { queryRunner compose.Runnable[string, []string] maxQueriesNum int origRetriever retriever.Retriever fusionFunc func(ctx context.Context, docs [][]*schema.Document) ([]*schema.Document, error) } // Retrieve retrieves documents from the multi-query retriever. func (m *multiQueryRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { // generate queries queries, err := m.queryRunner.Invoke(ctx, query) if err != nil { return nil, err } if len(queries) > m.maxQueriesNum { queries = queries[:m.maxQueriesNum] } // retrieve tasks := make([]*utils.RetrieveTask, len(queries)) for i := range queries { tasks[i] = &utils.RetrieveTask{Retriever: m.origRetriever, Query: queries[i]} } utils.ConcurrentRetrieveWithCallback(ctx, tasks) result := make([][]*schema.Document, len(queries)) for i, task := range tasks { if task.Err != nil { return nil, task.Err } result[i] = task.Result } // fusion ctx = ctxWithFusionRunInfo(ctx) ctx = callbacks.OnStart(ctx, result) fusionDocs, err := m.fusionFunc(ctx, result) if err != nil { callbacks.OnError(ctx, err) return nil, err } callbacks.OnEnd(ctx, fusionDocs) return fusionDocs, nil } // GetType returns the type of the retriever (MultiQuery). 
func (m *multiQueryRetriever) GetType() string { return "MultiQuery" } func ctxWithFusionRunInfo(ctx context.Context) context.Context { runInfo := &callbacks.RunInfo{ Component: compose.ComponentOfLambda, Type: "FusionFunc", } runInfo.Name = runInfo.Type + string(runInfo.Component) return callbacks.ReuseHandlers(ctx, runInfo) } ================================================ FILE: flow/retriever/multiquery/multi_query_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package multiquery import ( "context" "strings" "testing" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) type mockRetriever struct { } func (m *mockRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { var ret []*schema.Document if strings.Contains(query, "1") { ret = append(ret, &schema.Document{ID: "1"}) } if strings.Contains(query, "2") { ret = append(ret, &schema.Document{ID: "2"}) } if strings.Contains(query, "3") { ret = append(ret, &schema.Document{ID: "3"}) } if strings.Contains(query, "4") { ret = append(ret, &schema.Document{ID: "4"}) } if strings.Contains(query, "5") { ret = append(ret, &schema.Document{ID: "5"}) } return ret, nil } type mockModel struct { } func (m *mockModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { return &schema.Message{ Content: "12\n23\n34\n14\n23\n45", }, nil } func (m *mockModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { panic("implement me") } func (m *mockModel) BindTools(tools []*schema.ToolInfo) error { panic("implement me") } func TestMultiQueryRetriever(t *testing.T) { ctx := context.Background() // use default llm mqr, err := NewRetriever(ctx, &Config{ RewriteLLM: &mockModel{}, OrigRetriever: &mockRetriever{}, }) if err != nil { t.Fatal(err) } c := compose.NewChain[string, []*schema.Document]() cr, err := c.AppendRetriever(mqr).Compile(ctx) if err != nil { t.Fatal(err) } result, err := cr.Invoke(ctx, "query") if err != nil { t.Fatal(err) } if len(result) != 4 { t.Fatal("default llm retrieve result is unexpected") } // use custom mqr, err = NewRetriever(ctx, &Config{ RewriteHandler: func(ctx context.Context, query string) ([]string, error) { return []string{"1", "3", "5"}, nil }, OrigRetriever: 
&mockRetriever{}, }) if err != nil { t.Fatal(err) } c = compose.NewChain[string, []*schema.Document]() cr, err = c.AppendRetriever(mqr).Compile(ctx) if err != nil { t.Fatal(err) } result, err = cr.Invoke(ctx, "query") if err != nil { t.Fatal(err) } if len(result) != 3 { t.Fatal("default llm retrieve result is unexpected") } } ================================================ FILE: flow/retriever/parent/doc.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package parent provides a retriever that maps sub-document results // back to their original parent documents. package parent ================================================ FILE: flow/retriever/parent/parent.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package parent import ( "context" "fmt" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" ) // Config configures the parent retriever. type Config struct { // Retriever specifies the original retriever used to retrieve documents. // For example: a vector database retriever like Milvus, or a full-text search retriever like Elasticsearch. Retriever retriever.Retriever // ParentIDKey specifies the key used in the sub-document metadata to store the parent document ID. // Documents without this key will be removed from the recall results. // For example: if ParentIDKey is "parent_id", it will look for metadata like: // {"parent_id": "original_doc_123"} ParentIDKey string // OrigDocGetter specifies the method for getting original documents by ids from the sub-document metadata. // Parameters: // - ctx: context for the operation // - ids: slice of parent document IDs to retrieve // Returns: // - []*schema.Document: slice of retrieved parent documents // - error: any error encountered during retrieval // // For example: if sub-documents with parent IDs ["doc_1", "doc_2"] are retrieved, // OrigDocGetter will be called to fetch the original documents with these IDs. OrigDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error) } // NewRetriever creates a new parent retriever that handles retrieving original documents // based on sub-document search results. 
// // Parameters: // - ctx: context for the operation // - config: configuration for the parent retriever // // Example usage: // // retriever, err := NewRetriever(ctx, &Config{ // Retriever: milvusRetriever, // ParentIDKey: "source_doc_id", // OrigDocGetter: func(ctx context.Context, ids []string) ([]*schema.Document, error) { // return documentStore.GetByIDs(ctx, ids) // }, // }) // // Returns: // - retriever.Retriever: the created parent retriever // - error: any error encountered during creation func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) { if config.Retriever == nil { return nil, fmt.Errorf("retriever is required") } if config.OrigDocGetter == nil { return nil, fmt.Errorf("orig doc getter is required") } return &parentRetriever{ retriever: config.Retriever, parentIDKey: config.ParentIDKey, origDocGetter: config.OrigDocGetter, }, nil } type parentRetriever struct { retriever retriever.Retriever parentIDKey string origDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error) } func (p *parentRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { subDocs, err := p.retriever.Retrieve(ctx, query, opts...) if err != nil { return nil, err } ids := make([]string, 0, len(subDocs)) for _, subDoc := range subDocs { if k, ok := subDoc.MetaData[p.parentIDKey]; ok { if s, okk := k.(string); okk && !inList(s, ids) { ids = append(ids, s) } } } return p.origDocGetter(ctx, ids) } func inList(elem string, list []string) bool { for _, v := range list { if v == elem { return true } } return false } ================================================ FILE: flow/retriever/parent/parent_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package parent import ( "context" "reflect" "testing" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" ) type testRetriever struct{} func (t *testRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { ret := make([]*schema.Document, 0) for i := range query { ret = append(ret, &schema.Document{ ID: "", Content: "", MetaData: map[string]interface{}{ "parent": query[i : i+1], }, }) } return ret, nil } func TestParentRetriever(t *testing.T) { tests := []struct { name string config *Config input string want []*schema.Document }{ { name: "success", config: &Config{ Retriever: &testRetriever{}, ParentIDKey: "parent", OrigDocGetter: func(ctx context.Context, ids []string) ([]*schema.Document, error) { var ret []*schema.Document for i := range ids { ret = append(ret, &schema.Document{ID: ids[i]}) } return ret, nil }, }, input: "123233", want: []*schema.Document{ {ID: "1"}, {ID: "2"}, {ID: "3"}, }, }, } ctx := context.Background() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r, err := NewRetriever(ctx, tt.config) if err != nil { t.Fatal(err) } ret, err := r.Retrieve(ctx, tt.input) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(ret, tt.want) { t.Errorf("got %v, want %v", ret, tt.want) } }) } } ================================================ FILE: flow/retriever/router/router.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file 
except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package router provides retrieval routing helpers that merge results // from multiple retrievers and apply ranking strategies. package router import ( "context" "fmt" "sort" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/retriever/utils" "github.com/cloudwego/eino/schema" ) var rrf = func(ctx context.Context, result map[string][]*schema.Document) ([]*schema.Document, error) { if len(result) < 1 { return nil, fmt.Errorf("no documents") } if len(result) == 1 { for _, docs := range result { return docs, nil } } docRankMap := make(map[string]float64) docMap := make(map[string]*schema.Document) for _, v := range result { for i := range v { docMap[v[i].ID] = v[i] if _, ok := docRankMap[v[i].ID]; !ok { docRankMap[v[i].ID] = 1.0 / float64(i+60) } else { docRankMap[v[i].ID] += 1.0 / float64(i+60) } } } docList := make([]*schema.Document, 0, len(docMap)) for id := range docMap { docList = append(docList, docMap[id]) } sort.Slice(docList, func(i, j int) bool { return docRankMap[docList[i].ID] > docRankMap[docList[j].ID] }) return docList, nil } // NewRetriever creates a router retriever. // router retriever is useful when you want to retrieve documents from multiple retrievers with different queries. // eg. // // routerRetriever := router.NewRetriever(ctx, &router.Config{}) // docs, err := routerRetriever.Retrieve(ctx, "how to build agent with eino") // if err != nil { // ... 
// } // println(docs) func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) { if len(config.Retrievers) == 0 { return nil, fmt.Errorf("retrievers is empty") } router := config.Router if router == nil { var retrieverSet []string for k := range config.Retrievers { retrieverSet = append(retrieverSet, k) } router = func(ctx context.Context, query string) ([]string, error) { return retrieverSet, nil } } fusion := config.FusionFunc if fusion == nil { fusion = rrf } return &routerRetriever{ retrievers: config.Retrievers, router: config.Router, fusionFunc: fusion, }, nil } // Config is the config for router retriever. type Config struct { // Retrievers is the retrievers to be used. Retrievers map[string]retriever.Retriever // Router is the function to route the query to the retrievers. Router func(ctx context.Context, query string) ([]string, error) // FusionFunc is the function to fuse the documents from the retrievers. FusionFunc func(ctx context.Context, result map[string][]*schema.Document) ([]*schema.Document, error) } type routerRetriever struct { retrievers map[string]retriever.Retriever router func(ctx context.Context, query string) ([]string, error) fusionFunc func(ctx context.Context, result map[string][]*schema.Document) ([]*schema.Document, error) } // Retrieve retrieves documents from the router retriever. 
func (e *routerRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { routeCtx := ctxWithRouterRunInfo(ctx) routeCtx = callbacks.OnStart(routeCtx, query) retrieverNames, err := e.router(routeCtx, query) if err != nil { callbacks.OnError(routeCtx, err) return nil, err } if len(retrieverNames) == 0 { err = fmt.Errorf("no retriever has been selected") callbacks.OnError(routeCtx, err) return nil, err } callbacks.OnEnd(routeCtx, retrieverNames) // retrieve tasks := make([]*utils.RetrieveTask, len(retrieverNames)) for i := range retrieverNames { r, ok := e.retrievers[retrieverNames[i]] if !ok { return nil, fmt.Errorf("router output[%s] has not registered", retrieverNames[i]) } tasks[i] = &utils.RetrieveTask{ Name: retrieverNames[i], Retriever: r, Query: query, RetrieveOptions: opts, } } utils.ConcurrentRetrieveWithCallback(ctx, tasks) result := make(map[string][]*schema.Document) for i := range tasks { if tasks[i].Err != nil { return nil, tasks[i].Err } result[tasks[i].Name] = tasks[i].Result } // fusion fusionCtx := ctxWithFusionRunInfo(ctx) fusionCtx = callbacks.OnStart(fusionCtx, result) fusionDocs, err := e.fusionFunc(fusionCtx, result) if err != nil { callbacks.OnError(fusionCtx, err) return nil, err } callbacks.OnEnd(fusionCtx, fusionDocs) return fusionDocs, nil } // GetType returns the type of the retriever (Router). 
func (e *routerRetriever) GetType() string { return "Router" } func ctxWithRouterRunInfo(ctx context.Context) context.Context { runInfo := &callbacks.RunInfo{ Component: compose.ComponentOfLambda, Type: "Router", } runInfo.Name = runInfo.Type + string(runInfo.Component) return callbacks.ReuseHandlers(ctx, runInfo) } func ctxWithFusionRunInfo(ctx context.Context) context.Context { runInfo := &callbacks.RunInfo{ Component: compose.ComponentOfLambda, Type: "FusionFunc", } runInfo.Name = runInfo.Type + string(runInfo.Component) return callbacks.ReuseHandlers(ctx, runInfo) } ================================================ FILE: flow/retriever/router/router_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package router import ( "context" "reflect" "strings" "testing" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" ) type mockRetriever struct { } func (m *mockRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { var ret []*schema.Document if strings.Contains(query, "1") { ret = append(ret, &schema.Document{ID: "1"}) } if strings.Contains(query, "2") { ret = append(ret, &schema.Document{ID: "2"}) } if strings.Contains(query, "3") { ret = append(ret, &schema.Document{ID: "3"}) } if strings.Contains(query, "4") { ret = append(ret, &schema.Document{ID: "4"}) } if strings.Contains(query, "5") { ret = append(ret, &schema.Document{ID: "5"}) } return ret, nil } func (m *mockRetriever) GetType() string { return "Mock" } func TestRouterRetriever(t *testing.T) { ctx := context.Background() r, err := NewRetriever(ctx, &Config{ Retrievers: map[string]retriever.Retriever{ "1": &mockRetriever{}, "2": &mockRetriever{}, "3": &mockRetriever{}, }, Router: func(ctx context.Context, query string) ([]string, error) { return []string{"2", "3"}, nil }, FusionFunc: func(ctx context.Context, result map[string][]*schema.Document) ([]*schema.Document, error) { var ret []*schema.Document for _, v := range result { ret = append(ret, v...) } return ret, nil }, }) if err != nil { t.Fatal(err) } handler := callbacks.NewHandlerBuilder(). 
OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { switch info.Name { case "FusionFuncLambda": if _, ok := output.([]*schema.Document); !ok { t.Fatal("FusionFuncLambda output is not a []*schema.Document") } case "RouterLambda": if _, ok := output.([]string); !ok { t.Fatal("RouterLambda output is not a []string") } case "MockRetriever": if _, ok := output.([]*schema.Document); !ok { t.Fatal("MockRetriever output is not a []string") } default: t.Fatalf("unknown name: %s", info.Name) } return ctx }). OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { t.Fatal(err) return ctx }).Build() ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{}, handler) result, err := r.Retrieve(ctx, "3") if err != nil { t.Fatal(err) } if len(result) != 2 { t.Fatal("expected 2 results") } } func TestRRF(t *testing.T) { doc1 := &schema.Document{ID: "1"} doc2 := &schema.Document{ID: "2"} doc3 := &schema.Document{ID: "3"} doc4 := &schema.Document{ID: "4"} doc5 := &schema.Document{ID: "5"} input := map[string][]*schema.Document{ "1": {doc1, doc2, doc3, doc4, doc5}, "2": {doc2, doc3, doc4, doc5, doc1}, "3": {doc3, doc4, doc5, doc1, doc2}, } result, err := rrf(context.Background(), input) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(result, []*schema.Document{doc3, doc2, doc4, doc1, doc5}) { t.Fatal("rrf fail") } } ================================================ FILE: flow/retriever/utils/utils.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package utils provides helper utilities for retriever flows, including // concurrent retrieval with callback instrumentation. package utils import ( "context" "fmt" "sync" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" ) // RetrieveTask is a task for retrieving documents. // RetrieveTask represents a single retrieval job with its result or error. type RetrieveTask struct { Name string Retriever retriever.Retriever Query string RetrieveOptions []retriever.Option Result []*schema.Document Err error } // ConcurrentRetrieveWithCallback concurrently retrieves documents with callback. func ConcurrentRetrieveWithCallback(ctx context.Context, tasks []*RetrieveTask) { wg := sync.WaitGroup{} for i := range tasks { wg.Add(1) go func(ctx context.Context, t *RetrieveTask) { ctx = ctxWithRetrieverRunInfo(ctx, t.Retriever) defer func() { if e := recover(); e != nil { t.Err = fmt.Errorf("retrieve panic, query: %s, error: %v", t.Query, e) ctx = callbacks.OnError(ctx, t.Err) } wg.Done() }() ctx = callbacks.OnStart(ctx, t.Query) docs, err := t.Retriever.Retrieve(ctx, t.Query, t.RetrieveOptions...) 
if err != nil { callbacks.OnError(ctx, err) t.Err = err return } callbacks.OnEnd(ctx, docs) t.Result = docs }(ctx, tasks[i]) } wg.Wait() } func ctxWithRetrieverRunInfo(ctx context.Context, r retriever.Retriever) context.Context { runInfo := &callbacks.RunInfo{ Component: components.ComponentOfRetriever, } if typ, okk := components.GetType(r); okk { runInfo.Type = typ } runInfo.Name = runInfo.Type + string(runInfo.Component) return callbacks.ReuseHandlers(ctx, runInfo) } ================================================ FILE: go.mod ================================================ module github.com/cloudwego/eino go 1.18 require ( github.com/bmatcuk/doublestar/v4 v4.10.0 github.com/bytedance/sonic v1.15.0 github.com/eino-contrib/jsonschema v1.0.3 github.com/google/uuid v1.6.0 github.com/nikolalohinski/gonja v1.5.3 github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.10.0 github.com/wk8/go-ordered-map/v2 v2.1.8 go.uber.org/mock v0.4.0 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.9 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.9.3 
// indirect github.com/smarty/assertions v1.15.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/yargevad/filepathx v1.0.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect golang.org/x/sys v0.26.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) ================================================ FILE: go.sum ================================================ github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= github.com/bmatcuk/doublestar/v4 v4.10.0 h1:zU9WiOla1YA122oLM6i4EXvGW62DvKZVxIe6TYWexEs= github.com/bmatcuk/doublestar/v4 v4.10.0/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= github.com/bytedance/sonic/loader 
v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= github.com/gopherjs/gopherjs 
v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod 
h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x-cray/logrus-prefixed-formatter v0.5.2 
h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 
v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: internal/callbacks/inject.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package callbacks import ( "context" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/schema" ) func InitCallbacks(ctx context.Context, info *RunInfo, handlers ...Handler) context.Context { mgr, ok := newManager(info, handlers...) 
if ok { return ctxWithManager(ctx, mgr) } return ctxWithManager(ctx, nil) } func EnsureRunInfo(ctx context.Context, typ string, comp components.Component) context.Context { cbm, ok := managerFromCtx(ctx) if !ok { return InitCallbacks(ctx, &RunInfo{ Type: typ, Component: comp, }) } if cbm.runInfo == nil { return ReuseHandlers(ctx, &RunInfo{ Type: typ, Component: comp, }) } return ctx } func ReuseHandlers(ctx context.Context, info *RunInfo) context.Context { cbm, ok := managerFromCtx(ctx) if !ok { return InitCallbacks(ctx, info) } return ctxWithManager(ctx, cbm.withRunInfo(info)) } func AppendHandlers(ctx context.Context, info *RunInfo, handlers ...Handler) context.Context { cbm, ok := managerFromCtx(ctx) if !ok { return InitCallbacks(ctx, info, handlers...) } nh := make([]Handler, len(cbm.handlers)+len(handlers)) copy(nh[:len(cbm.handlers)], cbm.handlers) copy(nh[len(cbm.handlers):], handlers) return InitCallbacks(ctx, info, nh...) } type Handle[T any] func(context.Context, T, *RunInfo, []Handler) (context.Context, T) func On[T any](ctx context.Context, inOut T, handle Handle[T], timing CallbackTiming, start bool) (context.Context, T) { mgr, ok := managerFromCtx(ctx) if !ok { return ctx, inOut } nMgr := *mgr var info *RunInfo if start { info = nMgr.runInfo nMgr.runInfo = nil ctx = context.WithValue(ctx, CtxRunInfoKey{}, info) } else { if nMgr.runInfo != nil { info = nMgr.runInfo } else { info, _ = ctx.Value(CtxRunInfoKey{}).(*RunInfo) } } hs := make([]Handler, 0, len(nMgr.handlers)+len(nMgr.globalHandlers)) for _, handler := range append(nMgr.handlers, nMgr.globalHandlers...) 
{ timingChecker, ok_ := handler.(TimingChecker) if !ok_ || timingChecker.Needed(ctx, info, timing) { hs = append(hs, handler) } } var out T ctx, out = handle(ctx, inOut, info, hs) return ctxWithManager(ctx, &nMgr), out } func OnStartHandle[T any](ctx context.Context, input T, runInfo *RunInfo, handlers []Handler) (context.Context, T) { for i := len(handlers) - 1; i >= 0; i-- { ctx = handlers[i].OnStart(ctx, runInfo, input) } return ctx, input } func OnEndHandle[T any](ctx context.Context, output T, runInfo *RunInfo, handlers []Handler) (context.Context, T) { for _, handler := range handlers { ctx = handler.OnEnd(ctx, runInfo, output) } return ctx, output } func BuildOnEndHandleWithCopy[T any](copyFn func(T, int) []T) Handle[T] { return func(ctx context.Context, output T, runInfo *RunInfo, handlers []Handler) (context.Context, T) { if len(handlers) == 0 { return ctx, output } copies := copyFn(output, len(handlers)) for i, handler := range handlers { ctx = handler.OnEnd(ctx, runInfo, copies[i]) } return ctx, output } } func OnWithStreamHandle[S any]( ctx context.Context, inOut S, handlers []Handler, cpy func(int) []S, handle func(context.Context, Handler, S) context.Context) (context.Context, S) { if len(handlers) == 0 { return ctx, inOut } inOuts := cpy(len(handlers) + 1) for i, handler := range handlers { ctx = handle(ctx, handler, inOuts[i]) } return ctx, inOuts[len(inOuts)-1] } func OnStartWithStreamInputHandle[T any](ctx context.Context, input *schema.StreamReader[T], runInfo *RunInfo, handlers []Handler) (context.Context, *schema.StreamReader[T]) { handlers = generic.Reverse(handlers) cpy := input.Copy handle := func(ctx context.Context, handler Handler, in *schema.StreamReader[T]) context.Context { in_ := schema.StreamReaderWithConvert(in, func(i T) (CallbackInput, error) { return i, nil }) return handler.OnStartWithStreamInput(ctx, runInfo, in_) } return OnWithStreamHandle(ctx, input, handlers, cpy, handle) } func OnEndWithStreamOutputHandle[T any](ctx 
context.Context, output *schema.StreamReader[T], runInfo *RunInfo, handlers []Handler) (context.Context, *schema.StreamReader[T]) { cpy := output.Copy handle := func(ctx context.Context, handler Handler, out *schema.StreamReader[T]) context.Context { out_ := schema.StreamReaderWithConvert(out, func(i T) (CallbackOutput, error) { return i, nil }) return handler.OnEndWithStreamOutput(ctx, runInfo, out_) } return OnWithStreamHandle(ctx, output, handlers, cpy, handle) } func OnErrorHandle(ctx context.Context, err error, runInfo *RunInfo, handlers []Handler) (context.Context, error) { for _, handler := range handlers { ctx = handler.OnError(ctx, runInfo, err) } return ctx, err } ================================================ FILE: internal/callbacks/interface.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package callbacks import ( "context" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/schema" ) type RunInfo struct { // Name is the graph node name for display purposes, not unique. // Passed from compose.WithNodeName(). 
Name string Type string Component components.Component } type CallbackInput any type CallbackOutput any type Handler interface { OnStart(ctx context.Context, info *RunInfo, input CallbackInput) context.Context OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context OnError(ctx context.Context, info *RunInfo, err error) context.Context OnStartWithStreamInput(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context OnEndWithStreamOutput(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context } type CallbackTiming uint8 type TimingChecker interface { Needed(ctx context.Context, info *RunInfo, timing CallbackTiming) bool } ================================================ FILE: internal/callbacks/manager.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package callbacks import "context" type CtxManagerKey struct{} type CtxRunInfoKey struct{} type manager struct { globalHandlers []Handler handlers []Handler runInfo *RunInfo } var GlobalHandlers []Handler func newManager(runInfo *RunInfo, handlers ...Handler) (*manager, bool) { if len(handlers)+len(GlobalHandlers) == 0 { return nil, false } hs := make([]Handler, len(GlobalHandlers)) copy(hs, GlobalHandlers) return &manager{ globalHandlers: hs, handlers: handlers, runInfo: runInfo, }, true } func ctxWithManager(ctx context.Context, manager *manager) context.Context { return context.WithValue(ctx, CtxManagerKey{}, manager) } func (m *manager) withRunInfo(runInfo *RunInfo) *manager { if m == nil { return nil } n := *m n.runInfo = runInfo return &n } func managerFromCtx(ctx context.Context) (*manager, bool) { v := ctx.Value(CtxManagerKey{}) m, ok := v.(*manager) if ok && m != nil { n := *m return &n, true } return nil, false } ================================================ FILE: internal/channel.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package internal import "sync" // UnboundedChan represents a channel with unlimited capacity type UnboundedChan[T any] struct { buffer []T // Internal buffer to store data mutex sync.Mutex // Mutex to protect buffer access notEmpty *sync.Cond // Condition variable to wait for data closed bool // Indicates if the channel has been closed } // NewUnboundedChan initializes and returns an UnboundedChan func NewUnboundedChan[T any]() *UnboundedChan[T] { ch := &UnboundedChan[T]{} ch.notEmpty = sync.NewCond(&ch.mutex) return ch } // Send puts an item into the channel func (ch *UnboundedChan[T]) Send(value T) { ch.mutex.Lock() defer ch.mutex.Unlock() if ch.closed { panic("send on closed channel") } ch.buffer = append(ch.buffer, value) ch.notEmpty.Signal() // Wake up one goroutine waiting to receive } // Receive gets an item from the channel (blocks if empty) func (ch *UnboundedChan[T]) Receive() (T, bool) { ch.mutex.Lock() defer ch.mutex.Unlock() for len(ch.buffer) == 0 && !ch.closed { ch.notEmpty.Wait() // Wait until data is available } if len(ch.buffer) == 0 { // Channel is closed and empty var zero T return zero, false } val := ch.buffer[0] ch.buffer = ch.buffer[1:] return val, true } // Close marks the channel as closed func (ch *UnboundedChan[T]) Close() { ch.mutex.Lock() defer ch.mutex.Unlock() if !ch.closed { ch.closed = true ch.notEmpty.Broadcast() // Wake up all waiting goroutines } } ================================================ FILE: internal/channel_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package internal import ( "sync" "testing" "time" ) func TestUnboundedChan_Send(t *testing.T) { ch := NewUnboundedChan[string]() // Test sending a value ch.Send("test") if len(ch.buffer) != 1 { t.Errorf("buffer length should be 1, got %d", len(ch.buffer)) } if ch.buffer[0] != "test" { t.Errorf("expected 'test', got '%s'", ch.buffer[0]) } // Test sending multiple values ch.Send("test2") ch.Send("test3") if len(ch.buffer) != 3 { t.Errorf("buffer length should be 3, got %d", len(ch.buffer)) } } func TestUnboundedChan_SendPanic(t *testing.T) { ch := NewUnboundedChan[int]() ch.Close() // Test sending to closed channel should panic defer func() { if r := recover(); r == nil { t.Error("sending to closed channel should panic") } }() ch.Send(1) } func TestUnboundedChan_Receive(t *testing.T) { ch := NewUnboundedChan[int]() // Send values ch.Send(1) ch.Send(2) // Test receiving values val, ok := ch.Receive() if !ok { t.Error("receive should succeed") } if val != 1 { t.Errorf("expected 1, got %d", val) } val, ok = ch.Receive() if !ok { t.Error("receive should succeed") } if val != 2 { t.Errorf("expected 2, got %d", val) } } func TestUnboundedChan_ReceiveFromClosed(t *testing.T) { ch := NewUnboundedChan[int]() ch.Close() // Test receiving from closed, empty channel val, ok := ch.Receive() if ok { t.Error("receive from closed, empty channel should return ok=false") } if val != 0 { t.Errorf("expected zero value, got %d", val) } // Test receiving from closed channel with values ch = NewUnboundedChan[int]() ch.Send(42) ch.Close() val, ok = ch.Receive() if !ok { t.Error("receive 
should succeed") } if val != 42 { t.Errorf("expected 42, got %d", val) } // After consuming all values val, ok = ch.Receive() if ok { t.Error("receive from closed, empty channel should return ok=false") } } func TestUnboundedChan_Close(t *testing.T) { ch := NewUnboundedChan[int]() // Test closing ch.Close() if !ch.closed { t.Error("channel should be marked as closed") } // Test double closing (should not panic) ch.Close() } func TestUnboundedChan_Concurrency(t *testing.T) { ch := NewUnboundedChan[int]() const numSenders = 5 const numReceivers = 3 const messagesPerSender = 100 var rwg, swg sync.WaitGroup rwg.Add(numReceivers) swg.Add(numSenders) // Start senders for i := 0; i < numSenders; i++ { go func(id int) { defer swg.Done() for j := 0; j < messagesPerSender; j++ { ch.Send(id*messagesPerSender + j) time.Sleep(time.Microsecond) // Small delay to increase concurrency chance } }(i) } // Start receivers received := make([]int, 0, numSenders*messagesPerSender) var mu sync.Mutex for i := 0; i < numReceivers; i++ { go func() { defer rwg.Done() for { val, ok := ch.Receive() if !ok { return } mu.Lock() received = append(received, val) mu.Unlock() } }() } // Wait for senders to finish swg.Wait() ch.Close() // Wait for all goroutines to finish rwg.Wait() // Verify we received all messages if len(received) != numSenders*messagesPerSender { t.Errorf("expected %d messages, got %d", numSenders*messagesPerSender, len(received)) } // Create a map to check for duplicates and missing values receivedMap := make(map[int]bool) for _, val := range received { receivedMap[val] = true } if len(receivedMap) != numSenders*messagesPerSender { t.Error("duplicate or missing messages detected") } } func TestUnboundedChan_BlockingReceive(t *testing.T) { ch := NewUnboundedChan[int]() // Test that Receive blocks when channel is empty receiveDone := make(chan bool) go func() { ch.Receive() receiveDone <- true }() // Check that receive is blocked select { case <-receiveDone: t.Error("Receive 
should block on empty channel") case <-time.After(50 * time.Millisecond): // This is expected } // Send a value to unblock ch.Send(1) // Now receive should complete select { case <-receiveDone: // This is expected case <-time.After(50 * time.Millisecond): t.Error("Receive should have unblocked") } } ================================================ FILE: internal/concat.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package internal import ( "fmt" "reflect" "strings" "time" "github.com/cloudwego/eino/internal/generic" ) var ( concatFuncs = map[reflect.Type]any{ generic.TypeOf[string](): concatStrings, generic.TypeOf[int8](): useLast[int8], generic.TypeOf[int16](): useLast[int16], generic.TypeOf[int32](): useLast[int32], generic.TypeOf[int64](): useLast[int64], generic.TypeOf[int](): useLast[int], generic.TypeOf[uint8](): useLast[uint8], generic.TypeOf[uint16](): useLast[uint16], generic.TypeOf[uint32](): useLast[uint32], generic.TypeOf[uint64](): useLast[uint64], generic.TypeOf[uint](): useLast[uint], generic.TypeOf[bool](): useLast[bool], generic.TypeOf[float32](): useLast[float32], generic.TypeOf[float64](): useLast[float64], generic.TypeOf[time.Time](): useLast[time.Time], generic.TypeOf[time.Duration](): useLast[time.Duration], } ) func useLast[T any](s []T) (T, error) { return s[len(s)-1], nil } func concatStrings(ss []string) (string, error) { var n int for _, s := range ss { n += len(s) } var b strings.Builder b.Grow(n) for _, s := range ss { _, err := b.WriteString(s) if err != nil { return "", err } } return b.String(), nil } func RegisterStreamChunkConcatFunc[T any](fn func([]T) (T, error)) { concatFuncs[generic.TypeOf[T]()] = fn } func GetConcatFunc(typ reflect.Type) func(reflect.Value) (reflect.Value, error) { if fn, ok := concatFuncs[typ]; ok { return func(a reflect.Value) (reflect.Value, error) { rvs := reflect.ValueOf(fn).Call([]reflect.Value{a}) var err error if !rvs[1].IsNil() { err = rvs[1].Interface().(error) } return rvs[0], err } } return nil } // ConcatItems the caller should ensure len(items) > 1 func ConcatItems[T any](items []T) (T, error) { typ := generic.TypeOf[T]() v := reflect.ValueOf(items) var cv reflect.Value var err error // handle map kind if typ.Kind() == reflect.Map { cv, err = concatMaps(v) } else { cv, err = concatSliceValue(v) } if err != nil { var t T return t, err } return cv.Interface().(T), nil } func concatMaps(ms reflect.Value) 
(reflect.Value, error) { typ := ms.Type().Elem() rms := reflect.MakeMap(reflect.MapOf(typ.Key(), generic.TypeOf[[]any]())) ret := reflect.MakeMap(typ) n := ms.Len() for i := 0; i < n; i++ { m := ms.Index(i) for _, key := range m.MapKeys() { vals := rms.MapIndex(key) if !vals.IsValid() { var s []any vals = reflect.ValueOf(s) } val := m.MapIndex(key) vals = reflect.Append(vals, val) rms.SetMapIndex(key, vals) } } for _, key := range rms.MapKeys() { vals := rms.MapIndex(key) anyVals := vals.Interface().([]any) if len(anyVals) == 1 { ele := anyVals[0] if ele == nil { // we cannot SetMapIndex with nil because it will delete the key ret.SetMapIndex(key, reflect.Zero(typ.Elem())) continue } ret.SetMapIndex(key, reflect.ValueOf(ele)) continue } v, err := toSliceValue(anyVals) if err != nil { return reflect.Value{}, err } var cv reflect.Value if v.Type().Elem().Kind() == reflect.Map { cv, err = concatMaps(v) } else { cv, err = concatSliceValue(v) } if err != nil { return reflect.Value{}, err } ret.SetMapIndex(key, cv) } return ret, nil } func concatSliceValue(val reflect.Value) (reflect.Value, error) { elmType := val.Type().Elem() if val.Len() == 1 { return val.Index(0), nil } f := GetConcatFunc(elmType) if f != nil { return f(val) } // if all elements in the slice are empty, return an empty value // if there is exactly one non-empty element in the slice, return that non-empty element // otherwise, throw an error. 
var filtered reflect.Value for i := 0; i < val.Len(); i++ { oneVal := val.Index(i) if !oneVal.IsZero() { if filtered.IsValid() { return reflect.Value{}, fmt.Errorf("cannot concat multiple non-zero value of type %s", elmType) } filtered = oneVal } } if !filtered.IsValid() { filtered = reflect.New(elmType).Elem() } return filtered, nil } func toSliceValue(vs []any) (reflect.Value, error) { typ := reflect.TypeOf(vs[0]) ret := reflect.MakeSlice(reflect.SliceOf(typ), len(vs), len(vs)) ret.Index(0).Set(reflect.ValueOf(vs[0])) for i := 1; i < len(vs); i++ { v := vs[i] vt := reflect.TypeOf(v) if typ != vt { return reflect.Value{}, fmt.Errorf("unexpected slice element type. Got %v, expected %v", typ, vt) } ret.Index(i).Set(reflect.ValueOf(v)) } return ret, nil } ================================================ FILE: internal/concat_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package internal import ( "testing" "github.com/stretchr/testify/assert" ) func TestConcat(t *testing.T) { t.Run("concat map chunks with nil value", func(t *testing.T) { c1 := map[string]any{ "a": map[string]any{ "b": map[string]any{ "c1": nil, }, }, } c2 := map[string]any{ "a": map[string]any{ "b": map[string]any{ "c2": "c2", }, }, } m, err := ConcatItems([]map[string]any{c1, c2}) assert.Nil(t, err) assert.Equal(t, map[string]any{ "a": map[string]any{ "b": map[string]any{ "c1": nil, "c2": "c2", }, }, }, m) }) } ================================================ FILE: internal/core/address.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package core import ( "context" "fmt" "strings" "sync" "github.com/cloudwego/eino/internal/generic" ) // AddressSegmentType defines the type of a segment in an execution address. type AddressSegmentType string // Address represents a full, hierarchical address to a point in the execution structure. type Address []AddressSegment // String converts an Address into its unique string representation. 
func (p Address) String() string { if p == nil { return "" } var sb strings.Builder for i, s := range p { sb.WriteString(string(s.Type)) sb.WriteString(":") sb.WriteString(s.ID) if s.SubID != "" { sb.WriteString(":") sb.WriteString(s.SubID) } if i != len(p)-1 { sb.WriteString(";") } } return sb.String() } func (p Address) Equals(other Address) bool { if len(p) != len(other) { return false } for i := range p { if p[i].Type != other[i].Type || p[i].ID != other[i].ID || p[i].SubID != other[i].SubID { return false } } return true } // AddressSegment represents a single segment in the hierarchical address of an execution point. // A sequence of AddressSegments uniquely identifies a location within a potentially nested structure. type AddressSegment struct { // ID is the unique identifier for this segment, e.g., the node's key or the tool's name. ID string // Type indicates whether this address segment is a graph node, a tool call, an agent, etc. Type AddressSegmentType // In some cases, ID alone are not unique enough, we need this SubID to guarantee uniqueness. // e.g. parallel tool calls with the same name but different tool call IDs. SubID string } type addrCtxKey struct{} type addrCtx struct { addr Address interruptState *InterruptState isResumeTarget bool resumeData any } type globalResumeInfoKey struct{} type globalResumeInfo struct { mu sync.Mutex id2ResumeData map[string]any id2ResumeDataUsed map[string]bool id2State map[string]InterruptState id2StateUsed map[string]bool id2Addr map[string]Address } // GetCurrentAddress returns the hierarchical address of the currently executing component. // The address is a sequence of segments, each identifying a structural part of the execution // like an agent, a graph node, or a tool call. This can be useful for logging or debugging. 
func GetCurrentAddress(ctx context.Context) Address { if p, ok := ctx.Value(addrCtxKey{}).(*addrCtx); ok { return p.addr } return nil } // AppendAddressSegment creates a new execution context for a sub-component (e.g., a graph node or a tool call). // // It extends the current context's address with a new segment and populates the new context with the // appropriate interrupt state and resume data for that specific sub-address. // // - ctx: The parent context, typically the one passed into the component's Invoke/Stream method. // - segType: The type of the new address segment (e.g., "node", "tool"). // - segID: The unique ID for the new address segment. func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID string, subID string) context.Context { // get current address currentAddress := GetCurrentAddress(ctx) if len(currentAddress) == 0 { currentAddress = []AddressSegment{ { Type: segType, ID: segID, SubID: subID, }, } } else { newAddress := make([]AddressSegment, len(currentAddress)+1) copy(newAddress, currentAddress) newAddress[len(newAddress)-1] = AddressSegment{ Type: segType, ID: segID, SubID: subID, } currentAddress = newAddress } runCtx := &addrCtx{ addr: currentAddress, } rInfo, hasRInfo := getResumeInfo(ctx) if !hasRInfo { return context.WithValue(ctx, addrCtxKey{}, runCtx) } var id string for id_, addr := range rInfo.id2Addr { if addr.Equals(currentAddress) { rInfo.mu.Lock() if used, ok := rInfo.id2StateUsed[id_]; !ok || !used { runCtx.interruptState = generic.PtrOf(rInfo.id2State[id_]) rInfo.id2StateUsed[id_] = true id = id_ rInfo.mu.Unlock() break } rInfo.mu.Unlock() } } // take from globalResumeInfo the data for the new address if there is any rInfo.mu.Lock() defer rInfo.mu.Unlock() used := rInfo.id2ResumeDataUsed[id] if !used { rData, existed := rInfo.id2ResumeData[id] if existed { rInfo.id2ResumeDataUsed[id] = true runCtx.resumeData = rData runCtx.isResumeTarget = true } } // Also mark as resume target if any descendant 
address is a resume target. // This allows composite components (e.g., a tool containing a nested graph) to know // they should execute their children to reach the actual resume target. // We only consider descendants whose resume data has not yet been consumed. if !runCtx.isResumeTarget { for id_, addr := range rInfo.id2Addr { if len(addr) > len(currentAddress) && addr[:len(currentAddress)].Equals(currentAddress) { if !rInfo.id2ResumeDataUsed[id_] { runCtx.isResumeTarget = true break } } } } return context.WithValue(ctx, addrCtxKey{}, runCtx) } // GetNextResumptionPoints finds the immediate child resumption points for a given parent address. func GetNextResumptionPoints(ctx context.Context) (map[string]bool, error) { parentAddr := GetCurrentAddress(ctx) rInfo, exists := getResumeInfo(ctx) if !exists { return nil, fmt.Errorf("GetNextResumptionPoints: failed to get resume info from context") } nextPoints := make(map[string]bool) parentAddrLen := len(parentAddr) for _, addr := range rInfo.id2Addr { // Check if addr is a potential child (must be longer than parent) if len(addr) <= parentAddrLen { continue } // Check if it has the parent address as a prefix var isPrefix bool if parentAddrLen == 0 { isPrefix = true } else { isPrefix = addr[:parentAddrLen].Equals(parentAddr) } if !isPrefix { continue } // We are looking for immediate children. // The address of an immediate child should be one segment longer. childAddr := addr[parentAddrLen : parentAddrLen+1] childID := childAddr[0].ID // Avoid adding duplicates. if _, ok := nextPoints[childID]; !ok { nextPoints[childID] = true } } return nextPoints, nil } // BatchResumeWithData is the core function for preparing a resume context. It injects a map // of resume targets and their corresponding data into the context. // // The `resumeData` map should contain the interrupt IDs (which are the string form of addresses) of the // components to be resumed as keys. 
The value can be the resume data for that component, or `nil` // if no data is needed (equivalent to using `Resume`). // // This function is the foundation for the "Explicit Targeted Resume" strategy. Components whose interrupt IDs // are present as keys in the map will receive `isResumeFlow = true` when they call `GetResumeContext`. func BatchResumeWithData(ctx context.Context, resumeData map[string]any) context.Context { rInfo, ok := ctx.Value(globalResumeInfoKey{}).(*globalResumeInfo) if !ok { // Create a new globalResumeInfo and copy the map to prevent external mutation. newMap := make(map[string]any, len(resumeData)) for k, v := range resumeData { newMap[k] = v } return context.WithValue(ctx, globalResumeInfoKey{}, &globalResumeInfo{ id2ResumeData: newMap, id2ResumeDataUsed: make(map[string]bool), id2StateUsed: make(map[string]bool), }) } rInfo.mu.Lock() defer rInfo.mu.Unlock() if rInfo.id2ResumeData == nil { rInfo.id2ResumeData = make(map[string]any) } for id, data := range resumeData { rInfo.id2ResumeData[id] = data } return ctx } func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, id2State map[string]InterruptState) context.Context { rInfo, ok := ctx.Value(globalResumeInfoKey{}).(*globalResumeInfo) if ok { if rInfo.id2Addr == nil { rInfo.id2Addr = make(map[string]Address) } for id, addr := range id2Addr { rInfo.id2Addr[id] = addr } rInfo.id2State = id2State } else { rInfo = &globalResumeInfo{ id2Addr: id2Addr, id2State: id2State, id2StateUsed: make(map[string]bool), id2ResumeDataUsed: make(map[string]bool), } ctx = context.WithValue(ctx, globalResumeInfoKey{}, rInfo) } runCtx, ok := getRunCtx(ctx) if ok { for id_, addr := range id2Addr { if addr.Equals(runCtx.addr) { if used, ok := rInfo.id2StateUsed[id_]; !ok || !used { runCtx.interruptState = generic.PtrOf(rInfo.id2State[id_]) rInfo.mu.Lock() rInfo.id2StateUsed[id_] = true rInfo.mu.Unlock() } if used, ok := rInfo.id2ResumeDataUsed[id_]; !ok || !used { runCtx.isResumeTarget = true 
runCtx.resumeData = rInfo.id2ResumeData[id_] rInfo.mu.Lock() rInfo.id2ResumeDataUsed[id_] = true rInfo.mu.Unlock() } break } } } return ctx } func getResumeInfo(ctx context.Context) (*globalResumeInfo, bool) { info, ok := ctx.Value(globalResumeInfoKey{}).(*globalResumeInfo) return info, ok } type InterruptInfo struct { Info any IsRootCause bool } func (i *InterruptInfo) String() string { if i == nil { return "" } return fmt.Sprintf("interrupt info: Info=%v, IsRootCause=%v", i.Info, i.IsRootCause) } ================================================ FILE: internal/core/interrupt.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package core import ( "context" "fmt" "reflect" "github.com/google/uuid" ) type CheckPointStore interface { Get(ctx context.Context, checkPointID string) ([]byte, bool, error) Set(ctx context.Context, checkPointID string, checkPoint []byte) error } type InterruptSignal struct { ID string Address InterruptInfo InterruptState Subs []*InterruptSignal } func (is *InterruptSignal) Error() string { return fmt.Sprintf("interrupt signal: ID=%s, Addr=%s, Info=%s, State=%s, SubsLen=%d", is.ID, is.Address.String(), is.InterruptInfo.String(), is.InterruptState.String(), len(is.Subs)) } type InterruptState struct { State any LayerSpecificPayload any } func (is *InterruptState) String() string { if is == nil { return "" } return fmt.Sprintf("interrupt state: State=%v, LayerSpecificPayload=%v", is.State, is.LayerSpecificPayload) } // InterruptConfig holds optional parameters for creating an interrupt. type InterruptConfig struct { LayerPayload any } // InterruptOption is a function that configures an InterruptConfig. type InterruptOption func(*InterruptConfig) // WithLayerPayload creates an option to attach layer-specific metadata // to the interrupt's state. 
func WithLayerPayload(payload any) InterruptOption { return func(c *InterruptConfig) { c.LayerPayload = payload } } func Interrupt(ctx context.Context, info any, state any, subContexts []*InterruptSignal, opts ...InterruptOption) ( *InterruptSignal, error) { addr := GetCurrentAddress(ctx) // Apply options to get config config := &InterruptConfig{} for _, opt := range opts { opt(config) } myPoint := InterruptInfo{ Info: info, } if len(subContexts) == 0 { myPoint.IsRootCause = true return &InterruptSignal{ ID: uuid.NewString(), Address: addr, InterruptInfo: myPoint, InterruptState: InterruptState{ State: state, LayerSpecificPayload: config.LayerPayload, }, }, nil } return &InterruptSignal{ ID: uuid.NewString(), Address: addr, InterruptInfo: myPoint, InterruptState: InterruptState{ State: state, LayerSpecificPayload: config.LayerPayload, }, Subs: subContexts, }, nil } // InterruptCtx provides a complete, user-facing context for a single, resumable interrupt point. type InterruptCtx struct { // ID is the unique, fully-qualified address of the interrupt point. // It is constructed by joining the individual Address segments, e.g., "agent:A;node:graph_a;tool:tool_call_123". // This ID should be used when providing resume data via ResumeWithData. ID string // Address is the structured sequence of AddressSegment segments that leads to the interrupt point. Address Address // Info is the user-facing information associated with the interrupt, provided by the component that triggered it. Info any // IsRootCause indicates whether the interrupt point is the exact root cause for an interruption. IsRootCause bool // Parent points to the context of the parent component in the interrupt chain. // It is nil for the top-level interrupt. 
Parent *InterruptCtx } func (ic *InterruptCtx) EqualsWithoutID(other *InterruptCtx) bool { if ic == nil && other == nil { return true } if ic == nil || other == nil { return false } if !ic.Address.Equals(other.Address) { return false } if ic.IsRootCause != other.IsRootCause { return false } if ic.Info != nil || other.Info != nil { if ic.Info == nil || other.Info == nil { return false } if !reflect.DeepEqual(ic.Info, other.Info) { return false } } if ic.Parent != nil || other.Parent != nil { if ic.Parent == nil || other.Parent == nil { return false } if !ic.Parent.EqualsWithoutID(other.Parent) { return false } } return true } // InterruptContextsProvider is an interface for errors that contain interrupt contexts. // This allows different packages to check for and extract interrupt contexts from errors // without needing to know the concrete error type. type InterruptContextsProvider interface { GetInterruptContexts() []*InterruptCtx } // FromInterruptContexts converts a list of user-facing InterruptCtx objects into an // internal InterruptSignal tree. It correctly handles common ancestors and ensures // that the resulting tree is consistent with the original interrupt chain. // // This method is primarily used by components that bridge different execution environments. // For example, an `adk.AgentTool` might catch an `adk.InterruptInfo`, extract the // `adk.InterruptCtx` objects from it, and then call this method on each one. The resulting // error signals are then typically aggregated into a single error using `compose.CompositeInterrupt` // to be returned from the tool's `InvokableRun` method. // FromInterruptContexts reconstructs a single InterruptSignal tree from a list of // user-facing InterruptCtx objects. It correctly merges common ancestors. 
func FromInterruptContexts(contexts []*InterruptCtx) *InterruptSignal { if len(contexts) == 0 { return nil } signalMap := make(map[string]*InterruptSignal) var rootSignal *InterruptSignal // getOrCreateSignal is a recursive helper that builds the tree bottom-up. var getOrCreateSignal func(*InterruptCtx) *InterruptSignal getOrCreateSignal = func(ctx *InterruptCtx) *InterruptSignal { if ctx == nil { return nil } // If we've already created a signal for this context, return it. if signal, exists := signalMap[ctx.ID]; exists { return signal } // Create the signal for the current context. newSignal := &InterruptSignal{ ID: ctx.ID, Address: ctx.Address, InterruptInfo: InterruptInfo{ Info: ctx.Info, IsRootCause: ctx.IsRootCause, }, } signalMap[ctx.ID] = newSignal // Cache it immediately. // Recursively ensure the parent exists. If it doesn't, this is the root. if parentSignal := getOrCreateSignal(ctx.Parent); parentSignal != nil { parentSignal.Subs = append(parentSignal.Subs, newSignal) } else { rootSignal = newSignal } return newSignal } // Process all contexts to ensure all branches of the tree are built. for _, ctx := range contexts { _ = getOrCreateSignal(ctx) } return rootSignal } // ToInterruptContexts converts the internal InterruptSignal tree into a list of // user-facing InterruptCtx objects for the root causes of the interruption. // Each returned context has its Parent field populated (if it has a parent), // allowing traversal up the interrupt chain. // // If allowedSegmentTypes is nil, all segment types are kept and addresses are unchanged. // If allowedSegmentTypes is provided, it: // 1. Filters the parent chain to only keep contexts whose leaf segment type is allowed // 2. 
Strips non-allowed segment types from all addresses func ToInterruptContexts(is *InterruptSignal, allowedSegmentTypes []AddressSegmentType) []*InterruptCtx { if is == nil { return nil } var rootCauseContexts []*InterruptCtx var buildContexts func(*InterruptSignal, *InterruptCtx) buildContexts = func(signal *InterruptSignal, parentCtx *InterruptCtx) { currentCtx := &InterruptCtx{ ID: signal.ID, Address: signal.Address, Info: signal.InterruptInfo.Info, IsRootCause: signal.InterruptInfo.IsRootCause, Parent: parentCtx, } if currentCtx.IsRootCause { rootCauseContexts = append(rootCauseContexts, currentCtx) } for _, subSignal := range signal.Subs { buildContexts(subSignal, currentCtx) } } buildContexts(is, nil) if len(allowedSegmentTypes) > 0 { allowedSet := make(map[AddressSegmentType]bool, len(allowedSegmentTypes)) for _, t := range allowedSegmentTypes { allowedSet[t] = true } for _, ctx := range rootCauseContexts { filterParentChain(ctx, allowedSet) encapsulateContextAddresses(ctx, allowedSet) } } return rootCauseContexts } func filterParentChain(ctx *InterruptCtx, allowedSet map[AddressSegmentType]bool) { if ctx == nil { return } parent := ctx.Parent for parent != nil { if len(parent.Address) > 0 && allowedSet[parent.Address[len(parent.Address)-1].Type] { break } parent = parent.Parent } ctx.Parent = parent filterParentChain(parent, allowedSet) } func encapsulateContextAddresses(ctx *InterruptCtx, allowedSet map[AddressSegmentType]bool) { for c := ctx; c != nil; c = c.Parent { newAddr := make(Address, 0, len(c.Address)) for _, seg := range c.Address { if allowedSet[seg.Type] { newAddr = append(newAddr, seg) } } c.Address = newAddr } } // SignalToPersistenceMaps flattens an InterruptSignal tree into two maps suitable for persistence in a checkpoint. 
func SignalToPersistenceMaps(is *InterruptSignal) (map[string]Address, map[string]InterruptState) { id2addr := make(map[string]Address) id2state := make(map[string]InterruptState) if is == nil { return id2addr, id2state } var traverse func(*InterruptSignal) traverse = func(signal *InterruptSignal) { // Add current signal's data to the maps. id2addr[signal.ID] = signal.Address id2state[signal.ID] = signal.InterruptState // The embedded struct // Recurse into children. for _, sub := range signal.Subs { traverse(sub) } } traverse(is) return id2addr, id2state } ================================================ FILE: internal/core/interrupt_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package core import ( "context" "testing" "github.com/stretchr/testify/assert" ) // Define AddressSegmentType constants locally to avoid dependency cycles const ( AddressSegmentAgent AddressSegmentType = "agent" AddressSegmentTool AddressSegmentType = "tool" AddressSegmentNode AddressSegmentType = "node" ) func TestInterruptConversion(t *testing.T) { // Test Case 1: Simple Chain (A -> B -> C) t.Run("SimpleChain", func(t *testing.T) { // Manually construct the user-facing contexts with parent pointers ctxA := &InterruptCtx{ID: "A", IsRootCause: false} ctxB := &InterruptCtx{ID: "B", Parent: ctxA, IsRootCause: false} ctxC := &InterruptCtx{ID: "C", Parent: ctxB, IsRootCause: true} // The input to FromInterruptContexts is just the root cause leaf node contexts := []*InterruptCtx{ctxC} // Convert from user-facing contexts to internal signal tree signal := FromInterruptContexts(contexts) // Assertions for the signal tree structure assert.NotNil(t, signal) assert.Equal(t, "A", signal.ID) assert.Len(t, signal.Subs, 1) assert.Equal(t, "B", signal.Subs[0].ID) assert.Len(t, signal.Subs[0].Subs, 1) assert.Equal(t, "C", signal.Subs[0].Subs[0].ID) assert.True(t, signal.Subs[0].Subs[0].IsRootCause) // Convert back from the signal tree to user-facing contexts finalContexts := ToInterruptContexts(signal, nil) // Assertions for the final list of contexts assert.Len(t, finalContexts, 1) finalC := finalContexts[0] assert.Equal(t, "C", finalC.ID) assert.True(t, finalC.IsRootCause) assert.NotNil(t, finalC.Parent) assert.Equal(t, "B", finalC.Parent.ID) assert.NotNil(t, finalC.Parent.Parent) assert.Equal(t, "A", finalC.Parent.Parent.ID) assert.Nil(t, finalC.Parent.Parent.Parent) }) // Test Case 2: Multiple Root Causes with Shared Parent (B -> D, C -> D) t.Run("MultipleRootsSharedParent", func(t *testing.T) { // Manually construct the contexts ctxD := &InterruptCtx{ID: "D", IsRootCause: false} ctxB := &InterruptCtx{ID: "B", Parent: ctxD, IsRootCause: true} ctxC := &InterruptCtx{ID: "C", 
Parent: ctxD, IsRootCause: true} // The input contains both root cause leaves contexts := []*InterruptCtx{ctxB, ctxC} // Convert to signal tree signal := FromInterruptContexts(contexts) // Assertions for the signal tree structure (should merge at D) assert.NotNil(t, signal) assert.Equal(t, "D", signal.ID) assert.Len(t, signal.Subs, 2) // Order of subs is not guaranteed, so we check for presence subIDs := []string{signal.Subs[0].ID, signal.Subs[1].ID} assert.Contains(t, subIDs, "B") assert.Contains(t, subIDs, "C") // Convert back to user-facing contexts finalContexts := ToInterruptContexts(signal, nil) // Assertions for the final list of contexts assert.Len(t, finalContexts, 2) finalIDs := []string{finalContexts[0].ID, finalContexts[1].ID} assert.Contains(t, finalIDs, "B") assert.Contains(t, finalIDs, "C") // Check parent linking for one of the branches var finalB *InterruptCtx if finalContexts[0].ID == "B" { finalB = finalContexts[0] } else { finalB = finalContexts[1] } assert.NotNil(t, finalB.Parent) assert.Equal(t, "D", finalB.Parent.ID) assert.Nil(t, finalB.Parent.Parent) }) // Test Case 3: Nil and Empty Inputs t.Run("NilAndEmpty", func(t *testing.T) { assert.Nil(t, FromInterruptContexts(nil)) assert.Nil(t, FromInterruptContexts([]*InterruptCtx{})) assert.Nil(t, ToInterruptContexts(nil, nil)) }) } func TestSignalToPersistenceMaps(t *testing.T) { // Test Case 1: Nil Signal t.Run("NilSignal", func(t *testing.T) { id2addr, id2state := SignalToPersistenceMaps(nil) assert.NotNil(t, id2addr) assert.NotNil(t, id2state) assert.Empty(t, id2addr) assert.Empty(t, id2state) }) // Test Case 2: Single Node Signal t.Run("SingleNode", func(t *testing.T) { signal := &InterruptSignal{ ID: "node1", Address: Address{ {Type: AddressSegmentAgent, ID: "agent1"}, }, InterruptState: InterruptState{ State: "test state", LayerSpecificPayload: "test payload", }, } id2addr, id2state := SignalToPersistenceMaps(signal) assert.Len(t, id2addr, 1) assert.Len(t, id2state, 1) assert.Equal(t, 
signal.Address, id2addr["node1"]) assert.Equal(t, signal.InterruptState, id2state["node1"]) }) // Test Case 3: Simple Tree Structure t.Run("SimpleTree", func(t *testing.T) { child1 := &InterruptSignal{ ID: "child1", Address: Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, }, InterruptState: InterruptState{ State: "child1 state", }, } child2 := &InterruptSignal{ ID: "child2", Address: Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool2"}, }, InterruptState: InterruptState{ State: "child2 state", }, } parent := &InterruptSignal{ ID: "parent", Address: Address{ {Type: AddressSegmentAgent, ID: "agent1"}, }, InterruptState: InterruptState{ State: "parent state", }, Subs: []*InterruptSignal{child1, child2}, } id2addr, id2state := SignalToPersistenceMaps(parent) // Should contain all 3 nodes assert.Len(t, id2addr, 3) assert.Len(t, id2state, 3) // Check parent node assert.Equal(t, parent.Address, id2addr["parent"]) assert.Equal(t, parent.InterruptState, id2state["parent"]) // Check child nodes assert.Equal(t, child1.Address, id2addr["child1"]) assert.Equal(t, child1.InterruptState, id2state["child1"]) assert.Equal(t, child2.Address, id2addr["child2"]) assert.Equal(t, child2.InterruptState, id2state["child2"]) }) // Test Case 4: Deeply Nested Tree t.Run("DeeplyNestedTree", func(t *testing.T) { leaf1 := &InterruptSignal{ ID: "leaf1", Address: Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, {Type: AddressSegmentNode, ID: "node1"}, }, InterruptState: InterruptState{ State: "leaf1 state", }, } leaf2 := &InterruptSignal{ ID: "leaf2", Address: Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, {Type: AddressSegmentNode, ID: "node2"}, }, InterruptState: InterruptState{ State: "leaf2 state", }, } middle := &InterruptSignal{ ID: "middle", Address: Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: 
AddressSegmentTool, ID: "tool1"}, }, InterruptState: InterruptState{ State: "middle state", }, Subs: []*InterruptSignal{leaf1, leaf2}, } root := &InterruptSignal{ ID: "root", Address: Address{ {Type: AddressSegmentAgent, ID: "agent1"}, }, InterruptState: InterruptState{ State: "root state", }, Subs: []*InterruptSignal{middle}, } id2addr, id2state := SignalToPersistenceMaps(root) // Should contain all 4 nodes assert.Len(t, id2addr, 4) assert.Len(t, id2state, 4) // Verify all nodes are present assert.Equal(t, root.Address, id2addr["root"]) assert.Equal(t, root.InterruptState, id2state["root"]) assert.Equal(t, middle.Address, id2addr["middle"]) assert.Equal(t, middle.InterruptState, id2state["middle"]) assert.Equal(t, leaf1.Address, id2addr["leaf1"]) assert.Equal(t, leaf1.InterruptState, id2state["leaf1"]) assert.Equal(t, leaf2.Address, id2addr["leaf2"]) assert.Equal(t, leaf2.InterruptState, id2state["leaf2"]) }) // Test Case 5: Complex Tree with Multiple Branches t.Run("ComplexTree", func(t *testing.T) { // Create a complex tree structure with multiple branches branch1Leaf1 := &InterruptSignal{ID: "b1l1", Address: Address{{Type: AddressSegmentAgent, ID: "a1"}}, InterruptState: InterruptState{State: "b1l1"}} branch1Leaf2 := &InterruptSignal{ID: "b1l2", Address: Address{{Type: AddressSegmentAgent, ID: "a1"}}, InterruptState: InterruptState{State: "b1l2"}} branch1 := &InterruptSignal{ID: "b1", Address: Address{{Type: AddressSegmentAgent, ID: "a1"}}, InterruptState: InterruptState{State: "b1"}, Subs: []*InterruptSignal{branch1Leaf1, branch1Leaf2}} branch2Leaf1 := &InterruptSignal{ID: "b2l1", Address: Address{{Type: AddressSegmentAgent, ID: "a1"}}, InterruptState: InterruptState{State: "b2l1"}} branch2 := &InterruptSignal{ID: "b2", Address: Address{{Type: AddressSegmentAgent, ID: "a1"}}, InterruptState: InterruptState{State: "b2"}, Subs: []*InterruptSignal{branch2Leaf1}} root := &InterruptSignal{ID: "root", Address: Address{{Type: AddressSegmentAgent, ID: "a1"}}, 
InterruptState: InterruptState{State: "root"}, Subs: []*InterruptSignal{branch1, branch2}} id2addr, id2state := SignalToPersistenceMaps(root) // Should contain all 6 nodes assert.Len(t, id2addr, 6) assert.Len(t, id2state, 6) // Verify all nodes are present expectedNodes := []string{"root", "b1", "b2", "b1l1", "b1l2", "b2l1"} for _, nodeID := range expectedNodes { assert.Contains(t, id2addr, nodeID) assert.Contains(t, id2state, nodeID) } }) // Test Case 6: Empty InterruptState Values t.Run("EmptyInterruptState", func(t *testing.T) { signal := &InterruptSignal{ ID: "node1", Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, InterruptState: InterruptState{ // Empty state values }, } id2addr, id2state := SignalToPersistenceMaps(signal) assert.Len(t, id2addr, 1) assert.Len(t, id2state, 1) assert.Equal(t, signal.Address, id2addr["node1"]) assert.Equal(t, signal.InterruptState, id2state["node1"]) }) } func TestGetCurrentAddress(t *testing.T) { // Test Case 1: No Address in Context t.Run("NoAddressInContext", func(t *testing.T) { ctx := context.Background() addr := GetCurrentAddress(ctx) assert.Nil(t, addr) }) // Test Case 2: Address in Context t.Run("AddressInContext", func(t *testing.T) { ctx := context.Background() expectedAddr := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, } // Create a context with address using internal addrCtx runCtx := &addrCtx{ addr: expectedAddr, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) addr := GetCurrentAddress(ctx) assert.Equal(t, expectedAddr, addr) }) } func TestGetNextResumptionPoints(t *testing.T) { // Test Case 1: No Resume Info in Context t.Run("NoResumeInfo", func(t *testing.T) { ctx := context.Background() _, err := GetNextResumptionPoints(ctx) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to get resume info") }) // Test Case 2: Empty Resume Info t.Run("EmptyResumeInfo", func(t *testing.T) { ctx := context.Background() rInfo := &globalResumeInfo{ 
id2Addr: make(map[string]Address), } ctx = context.WithValue(ctx, globalResumeInfoKey{}, rInfo) points, err := GetNextResumptionPoints(ctx) assert.NoError(t, err) assert.Empty(t, points) }) // Test Case 3: Valid Resume Points t.Run("ValidResumePoints", func(t *testing.T) { ctx := context.Background() // Set up current address currentAddr := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, } runCtx := &addrCtx{ addr: currentAddr, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) // Set up resume info with child addresses rInfo := &globalResumeInfo{ id2Addr: map[string]Address{ "child1": { {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, }, "child2": { {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool2"}, }, "unrelated": { {Type: AddressSegmentAgent, ID: "agent2"}, }, }, } ctx = context.WithValue(ctx, globalResumeInfoKey{}, rInfo) points, err := GetNextResumptionPoints(ctx) assert.NoError(t, err) assert.Len(t, points, 2) assert.True(t, points["tool1"]) assert.True(t, points["tool2"]) }) // Test Case 4: Root Address (Empty Parent) t.Run("RootAddress", func(t *testing.T) { ctx := context.Background() // Empty current address (root) runCtx := &addrCtx{ addr: Address{}, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) // Set up resume info with various addresses rInfo := &globalResumeInfo{ id2Addr: map[string]Address{ "agent1": { {Type: AddressSegmentAgent, ID: "agent1"}, }, "agent2": { {Type: AddressSegmentAgent, ID: "agent2"}, }, }, } ctx = context.WithValue(ctx, globalResumeInfoKey{}, rInfo) points, err := GetNextResumptionPoints(ctx) assert.NoError(t, err) assert.Len(t, points, 2) assert.True(t, points["agent1"]) assert.True(t, points["agent2"]) }) } func TestBatchResumeWithData(t *testing.T) { // Test Case 1: New Resume Data t.Run("NewResumeData", func(t *testing.T) { ctx := context.Background() resumeData := map[string]any{ "id1": "data1", "id2": "data2", } newCtx := BatchResumeWithData(ctx, 
resumeData) // Verify the data was set correctly rInfo, ok := newCtx.Value(globalResumeInfoKey{}).(*globalResumeInfo) assert.True(t, ok) assert.NotNil(t, rInfo) assert.Equal(t, "data1", rInfo.id2ResumeData["id1"]) assert.Equal(t, "data2", rInfo.id2ResumeData["id2"]) }) // Test Case 2: Merge with Existing Resume Data t.Run("MergeWithExisting", func(t *testing.T) { ctx := context.Background() // First call with initial data initialData := map[string]any{ "id1": "initial", } ctx = BatchResumeWithData(ctx, initialData) // Second call with additional data additionalData := map[string]any{ "id2": "additional", } newCtx := BatchResumeWithData(ctx, additionalData) // Verify both data sets are present rInfo, ok := newCtx.Value(globalResumeInfoKey{}).(*globalResumeInfo) assert.True(t, ok) assert.NotNil(t, rInfo) assert.Equal(t, "initial", rInfo.id2ResumeData["id1"]) assert.Equal(t, "additional", rInfo.id2ResumeData["id2"]) }) // Test Case 3: Empty Resume Data t.Run("EmptyResumeData", func(t *testing.T) { ctx := context.Background() newCtx := BatchResumeWithData(ctx, map[string]any{}) rInfo, ok := newCtx.Value(globalResumeInfoKey{}).(*globalResumeInfo) assert.True(t, ok) assert.NotNil(t, rInfo) assert.Empty(t, rInfo.id2ResumeData) }) } func TestGetInterruptState(t *testing.T) { // Test Case 1: No Interrupt State t.Run("NoInterruptState", func(t *testing.T) { ctx := context.Background() wasInterrupted, hasState, state := GetInterruptState[string](ctx) assert.False(t, wasInterrupted) assert.False(t, hasState) assert.Equal(t, "", state) }) // Test Case 2: With Interrupt State t.Run("WithInterruptState", func(t *testing.T) { ctx := context.Background() // Create a context with interrupt state runCtx := &addrCtx{ interruptState: &InterruptState{ State: "test state", }, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) wasInterrupted, hasState, state := GetInterruptState[string](ctx) assert.True(t, wasInterrupted) assert.True(t, hasState) assert.Equal(t, "test state", state) }) 
// Test Case 3: Wrong Type for Interrupt State t.Run("WrongType", func(t *testing.T) { ctx := context.Background() // Create a context with interrupt state of wrong type runCtx := &addrCtx{ interruptState: &InterruptState{ State: 123, // int instead of string }, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) wasInterrupted, hasState, state := GetInterruptState[string](ctx) assert.True(t, wasInterrupted) assert.False(t, hasState) // Should be false due to type mismatch assert.Equal(t, "", state) }) // Test Case 4: Nil Interrupt State t.Run("NilInterruptState", func(t *testing.T) { ctx := context.Background() // Create a context with nil interrupt state runCtx := &addrCtx{ interruptState: nil, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) wasInterrupted, hasState, state := GetInterruptState[string](ctx) assert.False(t, wasInterrupted) // Should be false because interruptState is nil assert.False(t, hasState) // Should be false because state is nil assert.Equal(t, "", state) }) } func TestGetResumeContext(t *testing.T) { // Test Case 1: Not Resume Target t.Run("NotResumeTarget", func(t *testing.T) { ctx := context.Background() isResumeTarget, hasData, data := GetResumeContext[string](ctx) assert.False(t, isResumeTarget) assert.False(t, hasData) assert.Equal(t, "", data) }) // Test Case 2: Resume Target with Data t.Run("ResumeTargetWithData", func(t *testing.T) { ctx := context.Background() // Create a context as resume target with data runCtx := &addrCtx{ isResumeTarget: true, resumeData: "resume data", } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) isResumeTarget, hasData, data := GetResumeContext[string](ctx) assert.True(t, isResumeTarget) assert.True(t, hasData) assert.Equal(t, "resume data", data) }) // Test Case 3: Resume Target without Data t.Run("ResumeTargetWithoutData", func(t *testing.T) { ctx := context.Background() // Create a context as resume target without data runCtx := &addrCtx{ isResumeTarget: true, resumeData: nil, } ctx = 
context.WithValue(ctx, addrCtxKey{}, runCtx) isResumeTarget, hasData, data := GetResumeContext[string](ctx) assert.True(t, isResumeTarget) assert.False(t, hasData) assert.Equal(t, "", data) }) // Test Case 4: Wrong Type for Resume Data t.Run("WrongType", func(t *testing.T) { ctx := context.Background() // Create a context with resume data of wrong type runCtx := &addrCtx{ isResumeTarget: true, resumeData: 123, // int instead of string } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) isResumeTarget, hasData, data := GetResumeContext[string](ctx) assert.True(t, isResumeTarget) assert.False(t, hasData) // Should be false due to type mismatch assert.Equal(t, "", data) }) } func TestWithLayerPayload(t *testing.T) { // Test Case 1: Basic Usage t.Run("BasicUsage", func(t *testing.T) { config := &InterruptConfig{} opt := WithLayerPayload("test payload") opt(config) assert.Equal(t, "test payload", config.LayerPayload) }) // Test Case 2: Nil Payload t.Run("NilPayload", func(t *testing.T) { config := &InterruptConfig{LayerPayload: "existing"} opt := WithLayerPayload(nil) opt(config) assert.Nil(t, config.LayerPayload) }) // Test Case 3: Complex Payload t.Run("ComplexPayload", func(t *testing.T) { config := &InterruptConfig{} payload := map[string]any{ "key1": "value1", "key2": 123, } opt := WithLayerPayload(payload) opt(config) assert.Equal(t, payload, config.LayerPayload) }) } func TestInterruptFunction(t *testing.T) { // Test Case 1: Simple Interrupt without SubContexts t.Run("SimpleInterrupt", func(t *testing.T) { ctx := context.Background() // Create a context with a mock address expectedAddr := Address{{Type: AddressSegmentAgent, ID: "test-agent"}} runCtx := &addrCtx{ addr: expectedAddr, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) info := "test info" state := "test state" signal, err := Interrupt(ctx, info, state, nil) assert.NoError(t, err) assert.NotNil(t, signal) assert.NotEmpty(t, signal.ID) assert.Equal(t, info, signal.Info) assert.Equal(t, state, 
signal.State) assert.True(t, signal.IsRootCause) assert.Equal(t, expectedAddr, signal.Address) }) // Test Case 2: Interrupt with SubContexts t.Run("InterruptWithSubContexts", func(t *testing.T) { ctx := context.Background() // Create a context with a mock address expectedAddr := Address{{Type: AddressSegmentAgent, ID: "parent-agent"}} runCtx := &addrCtx{ addr: expectedAddr, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) // Create sub contexts subContexts := []*InterruptSignal{ { ID: "child1", Address: Address{{Type: AddressSegmentAgent, ID: "child1"}}, }, { ID: "child2", Address: Address{{Type: AddressSegmentAgent, ID: "child2"}}, }, } info := "parent info" state := "parent state" signal, err := Interrupt(ctx, info, state, subContexts) assert.NoError(t, err) assert.NotNil(t, signal) assert.NotEmpty(t, signal.ID) assert.Equal(t, info, signal.Info) assert.Equal(t, state, signal.State) assert.False(t, signal.IsRootCause) // Should be false when there are sub contexts assert.Len(t, signal.Subs, 2) assert.Equal(t, "child1", signal.Subs[0].ID) assert.Equal(t, "child2", signal.Subs[1].ID) }) // Test Case 3: Interrupt with Options t.Run("InterruptWithOptions", func(t *testing.T) { ctx := context.Background() // Create a context with a mock address expectedAddr := Address{{Type: AddressSegmentAgent, ID: "test-agent"}} runCtx := &addrCtx{ addr: expectedAddr, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) info := "test info" state := "test state" layerPayload := "layer payload" signal, err := Interrupt(ctx, info, state, nil, WithLayerPayload(layerPayload)) assert.NoError(t, err) assert.NotNil(t, signal) assert.Equal(t, layerPayload, signal.LayerSpecificPayload) }) // Test Case 4: Empty SubContexts t.Run("EmptySubContexts", func(t *testing.T) { ctx := context.Background() // Create a context with a mock address expectedAddr := Address{{Type: AddressSegmentAgent, ID: "test-agent"}} runCtx := &addrCtx{ addr: expectedAddr, } ctx = context.WithValue(ctx, addrCtxKey{}, 
runCtx) info := "test info" state := "test state" signal, err := Interrupt(ctx, info, state, []*InterruptSignal{}) assert.NoError(t, err) assert.NotNil(t, signal) assert.True(t, signal.IsRootCause) // Should be true when sub contexts is empty assert.Empty(t, signal.Subs) }) } func TestAddressMethods(t *testing.T) { // Test Case 1: Address.String() t.Run("AddressString", func(t *testing.T) { addr := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, {Type: AddressSegmentNode, ID: "node1", SubID: "sub1"}, } result := addr.String() expected := "agent:agent1;tool:tool1;node:node1:sub1" assert.Equal(t, expected, result) }) // Test Case 2: Address.String() with empty address t.Run("EmptyAddressString", func(t *testing.T) { var addr Address result := addr.String() assert.Equal(t, "", result) }) // Test Case 3: Address.Equals() with equal addresses t.Run("AddressEquals", func(t *testing.T) { addr1 := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, } addr2 := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, } assert.True(t, addr1.Equals(addr2)) }) // Test Case 4: Address.Equals() with different addresses t.Run("AddressNotEquals", func(t *testing.T) { addr1 := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, } addr2 := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool2"}, } assert.False(t, addr1.Equals(addr2)) }) // Test Case 5: Address.Equals() with different lengths t.Run("AddressDifferentLengths", func(t *testing.T) { addr1 := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, {Type: AddressSegmentTool, ID: "tool1"}, } addr2 := Address{ {Type: AddressSegmentAgent, ID: "agent1"}, } assert.False(t, addr1.Equals(addr2)) }) // Test Case 6: Address.Equals() with SubID differences t.Run("AddressSubIDDifference", func(t *testing.T) { addr1 := Address{ {Type: 
AddressSegmentAgent, ID: "agent1", SubID: "sub1"}, } addr2 := Address{ {Type: AddressSegmentAgent, ID: "agent1", SubID: "sub2"}, } assert.False(t, addr1.Equals(addr2)) }) } func TestAppendAddressSegment(t *testing.T) { // Test Case 1: Append to empty address t.Run("AppendToEmpty", func(t *testing.T) { ctx := context.Background() newCtx := AppendAddressSegment(ctx, AddressSegmentAgent, "agent1", "") addr := GetCurrentAddress(newCtx) assert.Len(t, addr, 1) assert.Equal(t, AddressSegmentAgent, addr[0].Type) assert.Equal(t, "agent1", addr[0].ID) assert.Equal(t, "", addr[0].SubID) }) // Test Case 2: Append to existing address t.Run("AppendToExisting", func(t *testing.T) { ctx := context.Background() // First append ctx = AppendAddressSegment(ctx, AddressSegmentAgent, "agent1", "") // Second append newCtx := AppendAddressSegment(ctx, AddressSegmentTool, "tool1", "call1") addr := GetCurrentAddress(newCtx) assert.Len(t, addr, 2) assert.Equal(t, AddressSegmentAgent, addr[0].Type) assert.Equal(t, "agent1", addr[0].ID) assert.Equal(t, AddressSegmentTool, addr[1].Type) assert.Equal(t, "tool1", addr[1].ID) assert.Equal(t, "call1", addr[1].SubID) }) // Test Case 3: Append with SubID t.Run("AppendWithSubID", func(t *testing.T) { ctx := context.Background() newCtx := AppendAddressSegment(ctx, AddressSegmentTool, "tool1", "call123") addr := GetCurrentAddress(newCtx) assert.Len(t, addr, 1) assert.Equal(t, AddressSegmentTool, addr[0].Type) assert.Equal(t, "tool1", addr[0].ID) assert.Equal(t, "call123", addr[0].SubID) }) } func TestPopulateInterruptState(t *testing.T) { // Test Case 1: Populate with matching address t.Run("PopulateMatchingAddress", func(t *testing.T) { ctx := context.Background() // Set up current address currentAddr := Address{{Type: AddressSegmentAgent, ID: "agent1"}} runCtx := &addrCtx{ addr: currentAddr, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) // Set up interrupt state data id2Addr := map[string]Address{ "interrupt1": currentAddr, } id2State := 
map[string]InterruptState{ "interrupt1": {State: "test state"}, } newCtx := PopulateInterruptState(ctx, id2Addr, id2State) // Verify the state was populated wasInterrupted, hasState, state := GetInterruptState[string](newCtx) assert.True(t, wasInterrupted) assert.True(t, hasState) assert.Equal(t, "test state", state) }) // Test Case 2: Populate with non-matching address t.Run("PopulateNonMatchingAddress", func(t *testing.T) { ctx := context.Background() // Set up current address currentAddr := Address{{Type: AddressSegmentAgent, ID: "agent1"}} runCtx := &addrCtx{ addr: currentAddr, } ctx = context.WithValue(ctx, addrCtxKey{}, runCtx) // Set up interrupt state data with different address id2Addr := map[string]Address{ "interrupt1": {{Type: AddressSegmentAgent, ID: "agent2"}}, } id2State := map[string]InterruptState{ "interrupt1": {State: "test state"}, } newCtx := PopulateInterruptState(ctx, id2Addr, id2State) // Verify the state was NOT populated (no matching address) wasInterrupted, hasState, state := GetInterruptState[string](newCtx) assert.False(t, wasInterrupted) assert.False(t, hasState) assert.Equal(t, "", state) }) // Test Case 3: Populate with empty data t.Run("PopulateEmptyData", func(t *testing.T) { ctx := context.Background() newCtx := PopulateInterruptState(ctx, map[string]Address{}, map[string]InterruptState{}) // Verify no state was populated wasInterrupted, hasState, state := GetInterruptState[string](newCtx) assert.False(t, wasInterrupted) assert.False(t, hasState) assert.Equal(t, "", state) }) } func TestStringMethods(t *testing.T) { // Test Case 1: InterruptSignal.Error() t.Run("InterruptSignalError", func(t *testing.T) { signal := &InterruptSignal{ ID: "test-id", Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, InterruptInfo: InterruptInfo{ Info: "test info", }, InterruptState: InterruptState{ State: "test state", LayerSpecificPayload: "test payload", }, Subs: []*InterruptSignal{ {ID: "sub1"}, }, } errorStr := signal.Error() 
expectedContains := []string{ "interrupt signal:", "ID=test-id", "Addr=agent:agent1", "Info=interrupt info: Info=test info, IsRootCause=false", "State=interrupt state: State=test state, LayerSpecificPayload=test payload", "SubsLen=1", } for _, expected := range expectedContains { assert.Contains(t, errorStr, expected) } }) // Test Case 2: InterruptState.String() t.Run("InterruptStateString", func(t *testing.T) { state := &InterruptState{ State: "test state", LayerSpecificPayload: "test payload", } result := state.String() expected := "interrupt state: State=test state, LayerSpecificPayload=test payload" assert.Equal(t, expected, result) }) // Test Case 3: InterruptState.String() with nil t.Run("InterruptStateStringNil", func(t *testing.T) { var state *InterruptState result := state.String() assert.Equal(t, "", result) }) // Test Case 4: InterruptInfo.String() t.Run("InterruptInfoString", func(t *testing.T) { info := &InterruptInfo{ Info: "test info", IsRootCause: true, } result := info.String() expected := "interrupt info: Info=test info, IsRootCause=true" assert.Equal(t, expected, result) }) // Test Case 5: InterruptInfo.String() with nil t.Run("InterruptInfoStringNil", func(t *testing.T) { var info *InterruptInfo result := info.String() assert.Equal(t, "", result) }) } func TestInterruptCtxEqualsWithoutID(t *testing.T) { // Test Case 1: Equal contexts t.Run("EqualContexts", func(t *testing.T) { ctx1 := &InterruptCtx{ ID: "id1", Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, Info: "info1", IsRootCause: true, } ctx2 := &InterruptCtx{ ID: "id2", // Different ID should be ignored Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, Info: "info1", IsRootCause: true, } assert.True(t, ctx1.EqualsWithoutID(ctx2)) }) // Test Case 2: Different addresses t.Run("DifferentAddresses", func(t *testing.T) { ctx1 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, } ctx2 := &InterruptCtx{ Address: Address{{Type: 
AddressSegmentAgent, ID: "agent2"}}, } assert.False(t, ctx1.EqualsWithoutID(ctx2)) }) // Test Case 3: Different root cause flags t.Run("DifferentRootCause", func(t *testing.T) { ctx1 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, IsRootCause: true, } ctx2 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, IsRootCause: false, } assert.False(t, ctx1.EqualsWithoutID(ctx2)) }) // Test Case 4: Different info t.Run("DifferentInfo", func(t *testing.T) { ctx1 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, Info: "info1", } ctx2 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, Info: "info2", } assert.False(t, ctx1.EqualsWithoutID(ctx2)) }) // Test Case 5: Nil contexts t.Run("NilContexts", func(t *testing.T) { var ctx1 *InterruptCtx var ctx2 *InterruptCtx assert.True(t, ctx1.EqualsWithoutID(ctx2)) ctx3 := &InterruptCtx{} assert.False(t, ctx1.EqualsWithoutID(ctx3)) assert.False(t, ctx3.EqualsWithoutID(ctx1)) }) // Test Case 6: With parent contexts t.Run("WithParentContexts", func(t *testing.T) { parent1 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "parent"}}, } parent2 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "parent"}}, } ctx1 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, Parent: parent1, } ctx2 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, Parent: parent2, } assert.True(t, ctx1.EqualsWithoutID(ctx2)) }) // Test Case 7: Different parent contexts t.Run("DifferentParentContexts", func(t *testing.T) { parent1 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "parent1"}}, } parent2 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "parent2"}}, } ctx1 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, ID: "agent1"}}, Parent: parent1, } ctx2 := &InterruptCtx{ Address: Address{{Type: AddressSegmentAgent, 
ID: "agent1"}}, Parent: parent2, } assert.False(t, ctx1.EqualsWithoutID(ctx2)) }) } ================================================ FILE: internal/core/resume.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package core import "context" // GetInterruptState provides a type-safe way to check for and retrieve the persisted state from a previous interruption. // It is the primary function a component should use to understand its past state. // // It returns three values: // - wasInterrupted (bool): True if the node was part of a previous interruption, regardless of whether state was provided. // - state (T): The typed state object, if it was provided and matches type `T`. // - hasState (bool): True if state was provided during the original interrupt and successfully cast to type `T`. func GetInterruptState[T any](ctx context.Context) (wasInterrupted bool, hasState bool, state T) { rCtx, ok := getRunCtx(ctx) if !ok || rCtx.interruptState == nil { return } wasInterrupted = true if rCtx.interruptState.State == nil { return } state, hasState = rCtx.interruptState.State.(T) return } // GetResumeContext checks if the current component is the target of a resume operation // and retrieves any data provided by the user for that resumption. // // This function is typically called *after* a component has already determined it is in a // resumed state by calling GetInterruptState. 
// // It returns three values: // - isResumeTarget: A boolean that is true if the current component's address OR any of its // descendant addresses was explicitly targeted by a call to Resume() or ResumeWithData(). // This allows composite components (like tools containing nested graphs) to know they should // execute their children to reach the actual resume target. // - hasData: A boolean that is true if data was provided for this specific component (i.e., not nil) and is of type T. // - data: The typed data provided by the user, or the zero value of T when hasData is false. // // ### How to Use This Function: A Decision Framework // // The correct usage pattern depends on the application's desired resume strategy. // // #### Strategy 1: Implicit "Resume All" // In some use cases, any resume operation implies that *all* interrupted points should proceed. // For example, an application's UI might only provide a single "Continue" button for a set of // interruptions. In this model, a component can often just use `GetInterruptState` to see if // `wasInterrupted` is true and then proceed with its logic, as it can assume it is an intended target. // It may still call `GetResumeContext` to check for optional data, but the `isResumeTarget` flag is less critical. // // #### Strategy 2: Explicit "Targeted Resume" (Most Common) // For applications with multiple, distinct interrupt points that must be resumed independently, it is // crucial to differentiate which point is being resumed. This is the primary use case for the `isResumeTarget` flag. // - If `isResumeTarget` is `true`: Your component (or one of its descendants) is the target. // If `hasData` is true, you are the direct target and should consume the data. // If `hasData` is false, a descendant is the target—execute your children to reach it. // - If `isResumeTarget` is `false`: Neither you nor your descendants are the target. You MUST // re-interrupt (e.g., by returning `StatefulInterrupt(...)`) to preserve your state.
// // ### Guidance for Composite Components // // Composite components (like `Graph` or other `Runnable`s that contain sub-processes) have a dual role: // 1. Check for Self-Targeting: A composite component can itself be the target of a resume // operation, for instance, to modify its internal state. It may call `GetResumeContext` // to check for data targeted at its own address. // 2. Act as a Conduit: After checking for itself, its primary role is to re-execute its children, // allowing the resume context to flow down to them. It must not consume a resume signal // intended for one of its descendants. func GetResumeContext[T any](ctx context.Context) (isResumeTarget bool, hasData bool, data T) { rCtx, ok := getRunCtx(ctx) if !ok { return } isResumeTarget = rCtx.isResumeTarget if !isResumeTarget { return } // It is a resume flow, now check for data if rCtx.resumeData == nil { return // hasData is false } data, hasData = rCtx.resumeData.(T) return } func getRunCtx(ctx context.Context) (*addrCtx, bool) { rCtx, ok := ctx.Value(addrCtxKey{}).(*addrCtx) return rCtx, ok } ================================================ FILE: internal/generic/generic.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generic import ( "reflect" ) // NewInstance create an instance of the given type T. 
// the main purpose of this function is to create an instance of a type, can handle the type of T is a pointer or not. // eg. NewInstance[int] returns 0. // eg. NewInstance[*int] returns *0 (will be ptr of 0, not nil!). func NewInstance[T any]() T { typ := TypeOf[T]() switch typ.Kind() { case reflect.Map: return reflect.MakeMap(typ).Interface().(T) case reflect.Slice, reflect.Array: return reflect.MakeSlice(typ, 0, 0).Interface().(T) case reflect.Ptr: typ = typ.Elem() origin := reflect.New(typ) inst := origin for typ.Kind() == reflect.Ptr { typ = typ.Elem() inst = inst.Elem() inst.Set(reflect.New(typ)) } return origin.Interface().(T) default: var t T return t } } // TypeOf returns the type of T. // eg. TypeOf[int] returns reflect.TypeOf(int). // eg. TypeOf[*int] returns reflect.TypeOf(*int). func TypeOf[T any]() reflect.Type { return reflect.TypeOf((*T)(nil)).Elem() } // PtrOf returns a pointer of T. // useful when you want to get a pointer of a value, in some config, for example. // eg. PtrOf[int] returns *int. // eg. PtrOf[*int] returns **int. func PtrOf[T any](v T) *T { return &v } type Pair[F, S any] struct { First F Second S } // Reverse returns a new slice with elements in reversed order. func Reverse[S ~[]E, E any](s S) S { d := make(S, len(s)) for i := 0; i < len(s); i++ { d[i] = s[len(s)-i-1] } return d } // CopyMap copies a map to a new map. func CopyMap[K comparable, V any](src map[K]V) map[K]V { dst := make(map[K]V, len(src)) for k, v := range src { dst[k] = v } return dst } ================================================ FILE: internal/generic/generic_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generic import ( "testing" "github.com/stretchr/testify/assert" ) func TestNewInstance(t *testing.T) { t.Run("struct", func(t *testing.T) { type Test struct{} inst := NewInstance[Test]() assert.IsType(t, Test{}, inst) }) t.Run("pointer", func(t *testing.T) { type Test struct{} inst := NewInstance[*Test]() assert.IsType(t, &Test{}, inst) }) t.Run("interface", func(t *testing.T) { type Test interface{} inst := NewInstance[Test]() assert.IsType(t, Test(nil), inst) }) t.Run("pointer of pointer of pointer", func(t *testing.T) { type Test struct { Value int } inst := NewInstance[***Test]() ptr := &Test{} ptrOfPtr := &ptr assert.NotNil(t, inst) assert.NotNil(t, *inst) assert.IsType(t, ptrOfPtr, *inst) assert.NotNil(t, **inst) assert.Equal(t, Test{Value: 0}, ***inst) }) t.Run("primitive_map", func(t *testing.T) { inst := NewInstance[map[string]any]() assert.NotNil(t, inst) inst["a"] = 1 assert.Equal(t, map[string]any{"a": 1}, inst) }) t.Run("primitive_slice", func(t *testing.T) { inst := NewInstance[[]int]() assert.NotNil(t, inst) inst = append(inst, 1) assert.Equal(t, []int{1}, inst) }) t.Run("primitive_string", func(t *testing.T) { inst := NewInstance[string]() assert.Equal(t, "", inst) }) t.Run("primitive_int64", func(t *testing.T) { inst := NewInstance[int64]() assert.Equal(t, int64(0), inst) }) } func TestReverse(t *testing.T) { t.Run("reverse int slice", func(t *testing.T) { input := []int{1, 2, 3, 4, 5} expected := []int{5, 4, 3, 2, 1} result := Reverse(input) assert.Equal(t, expected, result) }) t.Run("reverse string slice", func(t *testing.T) { input := 
[]string{"a", "b", "c"} expected := []string{"c", "b", "a"} result := Reverse(input) assert.Equal(t, expected, result) }) } ================================================ FILE: internal/generic/type_name.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generic import ( "reflect" "regexp" "runtime" "strings" ) var ( regOfAnonymousFunc = regexp.MustCompile(`^func[0-9]+`) regOfNumber = regexp.MustCompile(`^\d+$`) ) // ParseTypeName returns the name of the type of the given value. // It takes a reflect.Value as input and processes it to determine the underlying type. If the type is a pointer, it dereferences it to get the actual type. (the optimization of this function) // eg: ParseTypeName(reflect.ValueOf(&&myStruct{})) returns "myStruct" (not "**myStruct") // // If the type is a function, it retrieves the function's name, handling both named and anonymous functions. 
// examples of function paths: [package_path].[receiver_type].[func_name] // named function: xxx/utils.ParseTypeName // method: xxx/utils.(*MyStruct).Method // anonymous function: xxx/utils.TestParseTypeName.func6.1 func ParseTypeName(val reflect.Value) string { typ := val.Type() for typ.Kind() == reflect.Pointer { typ = typ.Elem() } if typ.Kind() == reflect.Func { funcName := runtime.FuncForPC(val.Pointer()).Name() idx := strings.LastIndex(funcName, ".") if idx < 0 { if funcName != "" { return funcName } return "" } name := funcName[idx+1:] if regOfAnonymousFunc.MatchString(name) { return "" } if regOfNumber.MatchString(name) { return "" } return name } return typ.Name() } ================================================ FILE: internal/generic/type_name_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package generic import ( "reflect" "testing" "github.com/stretchr/testify/assert" ) func TestParseTypeName(t *testing.T) { t.Run("named_struct", func(t *testing.T) { type OpenAI struct{} model := &OpenAI{} name := ParseTypeName(reflect.Indirect(reflect.ValueOf(model))) assert.Equal(t, "OpenAI", name) }) t.Run("anonymous_struct", func(t *testing.T) { model := &struct{}{} name := ParseTypeName(reflect.ValueOf(model)) assert.Equal(t, "", name) }) t.Run("anonymous_struct_from_func", func(t *testing.T) { model := genStruct() name := ParseTypeName(reflect.ValueOf(model)) assert.Equal(t, "", name) }) t.Run("named_interface", func(t *testing.T) { type OpenAI interface{} model := OpenAI(&struct{}{}) name := ParseTypeName(reflect.ValueOf(model)) assert.Equal(t, "", name) name = ParseTypeName(reflect.ValueOf((*OpenAI)(nil))) assert.Equal(t, "OpenAI", name) }) t.Run("named_function", func(t *testing.T) { f := genOpenAI name := ParseTypeName(reflect.ValueOf(f)) assert.Equal(t, "genOpenAI", name) }) t.Run("anonymous_function", func(t *testing.T) { f := genAnonymousFunc() name := ParseTypeName(reflect.ValueOf(f)) assert.Equal(t, "", name) ff := func(n string) { _ = n } name = ParseTypeName(reflect.ValueOf(ff)) assert.Equal(t, "", name) }) } func genStruct() *struct{} { return &struct{}{} } func genOpenAI() {} func genAnonymousFunc() func(n string) { return func(n string) { _ = n } } ================================================ FILE: internal/gmap/gmap.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package gmap // Concat returns the unions of maps as a new map. // // 💡 NOTE: // // - Once the key conflicts, the newer value always replace the older one ([DiscardOld]), // - If the result is an empty set, always return an empty map instead of nil // // 🚀 EXAMPLE: // // m := map[int]int{1: 1, 2: 2} // Concat(m, nil) ⏩ map[int]int{1: 1, 2: 2} // Concat(m, map[int]{3: 3}) ⏩ map[int]int{1: 1, 2: 2, 3: 3} // Concat(m, map[int]{2: -1}) ⏩ map[int]int{1: 1, 2: -1} // "2:2" is replaced by the newer "2:-1" // // 💡 AKA: Merge, Union, Combine func Concat[K comparable, V any](ms ...map[K]V) map[K]V { // FastPath: no map or only one map given. if len(ms) == 0 { return make(map[K]V) } if len(ms) == 1 { return cloneWithoutNilCheck(ms[0]) } var maxLen int for _, m := range ms { if len(m) > maxLen { maxLen = len(m) } } ret := make(map[K]V, maxLen) // FastPath: all maps are empty. if maxLen == 0 { return ret } // Concat all maps. for _, m := range ms { for k, v := range m { ret[k] = v } } return ret } // Map applies function f to each key and value of map m. // Results of f are returned as a new map. // // 🚀 EXAMPLE: // // f := func(k, v int) (string, string) { return strconv.Itoa(k), strconv.Itoa(v) } // Map(map[int]int{1: 1}, f) ⏩ map[string]string{"1": "1"} // Map(map[int]int{}, f) ⏩ map[string]string{} func Map[K1, K2 comparable, V1, V2 any](m map[K1]V1, f func(K1, V1) (K2, V2)) map[K2]V2 { r := make(map[K2]V2, len(m)) for k, v := range m { k2, v2 := f(k, v) r[k2] = v2 } return r } // Values returns the values of the map m. 
// // 🚀 EXAMPLE: // // m := map[int]string{1: "1", 2: "2", 3: "3", 4: "4"} // Values(m) ⏩ []string{"1", "4", "2", "3"} //⚠️INDETERMINATE ORDER⚠️ // // ⚠️ WARNING: The keys values be in an indeterminate order, func Values[K comparable, V any](m map[K]V) []V { r := make([]V, 0, len(m)) for _, v := range m { r = append(r, v) } return r } // Clone returns a shallow copy of map. // If the given map is nil, nil is returned. // // 🚀 EXAMPLE: // // Clone(map[int]int{1: 1, 2: 2}) ⏩ map[int]int{1: 1, 2: 2} // Clone(map[int]int{}) ⏩ map[int]int{} // Clone[int, int](nil) ⏩ nil // // 💡 HINT: Both keys and values are copied using assignment (=), so this is a shallow clone. // 💡 AKA: Copy func Clone[K comparable, V any, M ~map[K]V](m M) M { if m == nil { return nil } return cloneWithoutNilCheck(m) } func cloneWithoutNilCheck[K comparable, V any, M ~map[K]V](m M) M { r := make(M, len(m)) for k, v := range m { r[k] = v } return r } ================================================ FILE: internal/gmap/gmap_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package gmap import ( "fmt" "sort" "strconv" "testing" "github.com/stretchr/testify/assert" ) func TestMerge(t *testing.T) { assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, Concat(map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, nil)) assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, Concat[int, int](nil, map[int]int{1: 1, 2: 2, 3: 3, 4: 4})) assert.Equal(t, map[int]int{}, Concat[int, int](nil, nil)) assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, Concat(map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, map[int]int{1: 1, 2: 2, 3: 3, 4: 4})) assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, Concat(map[int]int{1: 0, 2: 0}, map[int]int{1: 1, 2: 2, 3: 3, 4: 4})) assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, Concat(map[int]int{1: 1, 2: 1}, map[int]int{2: 2, 3: 3, 4: 4})) } func TestMap(t *testing.T) { assert.Equal(t, map[string]string{"1": "1", "2": "2"}, Map(map[int]int{1: 1, 2: 2}, func(k, v int) (string, string) { return strconv.Itoa(k), strconv.Itoa(v) })) assert.Equal(t, map[string]string{}, Map(map[int]int{}, func(k, v int) (string, string) { return strconv.Itoa(k), strconv.Itoa(v) })) } func TestValues(t *testing.T) { { keys := Values(map[int]string{1: "1", 2: "2", 3: "3", 4: "4"}) sort.Strings(keys) assert.Equal(t, []string{"1", "2", "3", "4"}, keys) } assert.Equal(t, []string{}, Values(map[int]string{})) assert.Equal(t, []string{}, Values[int, string](nil)) } func TestClone(t *testing.T) { assert.Equal(t, map[int]int{1: 1, 2: 2}, Clone(map[int]int{1: 1, 2: 2})) var nilMap map[int]int assert.Equal(t, map[int]int{}, Clone(map[int]int{})) assert.NotEqual(t, (map[int]int)(nil), Clone(map[int]int{})) assert.Equal(t, (map[int]int)(nil), Clone(nilMap)) assert.NotEqual(t, map[int]int{}, Clone(nilMap)) // Test new type. type I2I map[int]int assert.Equal(t, I2I{1: 1, 2: 2}, Clone(I2I{1: 1, 2: 2})) assert.Equal(t, "gmap.I2I", fmt.Sprintf("%T", Clone(I2I{}))) // Test shallow clone. 
src := map[int]*int{1: ptr(1), 2: ptr(2)} dst := Clone(src) assert.Equal(t, src, dst) assert.True(t, src[1] == dst[1]) assert.True(t, src[2] == dst[2]) } // Ptr returns a pointer to the given value. func ptr[T any](v T) *T { return &v } ================================================ FILE: internal/gslice/gslice.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package gslice // ToMap collects elements of slice to map, both map keys and values are produced // by mapping function f. // // 🚀 EXAMPLE: // // type Foo struct { // ID int // Name string // } // mapper := func(f Foo) (int, string) { return f.ID, f.Name } // ToMap([]Foo{}, mapper) ⏩ map[int]string{} // s := []Foo{{1, "one"}, {2, "two"}, {3, "three"}} // ToMap(s, mapper) ⏩ map[int]string{1: "one", 2: "two", 3: "three"} func ToMap[T, V any, K comparable](s []T, f func(T) (K, V)) map[K]V { m := make(map[K]V, len(s)) for _, e := range s { k, v := f(e) m[k] = v } return m } ================================================ FILE: internal/gslice/gslice_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package gslice import ( "testing" "github.com/stretchr/testify/assert" ) func TestToMap(t *testing.T) { type Foo struct { ID int Name string } mapper := func(f Foo) (int, string) { return f.ID, f.Name } assert.Equal(t, map[int]string{}, ToMap([]Foo{}, mapper)) assert.Equal(t, map[int]string{}, ToMap(nil, mapper)) assert.Equal(t, map[int]string{1: "one", 2: "two", 3: "three"}, ToMap([]Foo{{1, "one"}, {2, "two"}, {3, "three"}}, mapper)) } ================================================ FILE: internal/merge.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package internal import ( "fmt" "reflect" "github.com/cloudwego/eino/internal/generic" ) var mergeFuncs = map[reflect.Type]any{} func RegisterValuesMergeFunc[T any](fn func([]T) (T, error)) { mergeFuncs[generic.TypeOf[T]()] = fn } func GetMergeFunc(typ reflect.Type) func([]any) (any, error) { if fn, ok := mergeFuncs[typ]; ok { return func(vs []any) (any, error) { rvs := reflect.MakeSlice(reflect.SliceOf(typ), 0, len(vs)) for _, v := range vs { if t := reflect.TypeOf(v); t != typ { return nil, fmt.Errorf( "(values merge) field type mismatch. expected: '%v', got: '%v'", typ, t) } rvs = reflect.Append(rvs, reflect.ValueOf(v)) } rets := reflect.ValueOf(fn).Call([]reflect.Value{rvs}) var err error if !rets[1].IsNil() { err = rets[1].Interface().(error) } return rets[0].Interface(), err } } if typ.Kind() == reflect.Map { return func(vs []any) (any, error) { return mergeMap(typ, vs) } } return nil } func mergeMap(typ reflect.Type, vs []any) (any, error) { merged := reflect.MakeMap(typ) for _, v := range vs { if t := reflect.TypeOf(v); t != typ { return nil, fmt.Errorf( "(values merge map) field type mismatch. expected: '%v', got: '%v'", typ, t) } iter := reflect.ValueOf(v).MapRange() for iter.Next() { key, val := iter.Key(), iter.Value() if merged.MapIndex(key).IsValid() { return nil, fmt.Errorf("(values merge map) duplicated key ('%v') found", key.Interface()) } merged.SetMapIndex(key, val) } } return merged.Interface(), nil } ================================================ FILE: internal/mock/adk/Agent_mock.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Code generated by MockGen. DO NOT EDIT.
// Source: interface.go
//
// Generated by this command:
//
//	mockgen -destination ../internal/mock/adk/Agent_mock.go --package adk -source interface.go
//

// Package adk is a generated GoMock package.
package adk

import (
	context "context"
	reflect "reflect"

	adk "github.com/cloudwego/eino/adk"
	gomock "go.uber.org/mock/gomock"
)

// MockAgent is a mock of Agent interface.
type MockAgent struct {
	ctrl     *gomock.Controller
	recorder *MockAgentMockRecorder
	isgomock struct{}
}

// MockAgentMockRecorder is the mock recorder for MockAgent.
type MockAgentMockRecorder struct {
	mock *MockAgent
}

// NewMockAgent creates a new mock instance.
func NewMockAgent(ctrl *gomock.Controller) *MockAgent {
	mock := &MockAgent{ctrl: ctrl}
	mock.recorder = &MockAgentMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAgent) EXPECT() *MockAgentMockRecorder {
	return m.recorder
}

// Description mocks base method.
func (m *MockAgent) Description(ctx context.Context) string {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Description", ctx)
	ret0, _ := ret[0].(string)
	return ret0
}

// Description indicates an expected call of Description.
func (mr *MockAgentMockRecorder) Description(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Description", reflect.TypeOf((*MockAgent)(nil).Description), ctx)
}

// Name mocks base method.
func (m *MockAgent) Name(ctx context.Context) string {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Name", ctx)
	ret0, _ := ret[0].(string)
	return ret0
}

// Name indicates an expected call of Name.
func (mr *MockAgentMockRecorder) Name(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockAgent)(nil).Name), ctx)
}

// Run mocks base method.
func (m *MockAgent) Run(ctx context.Context, input *adk.AgentInput, options ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
	m.ctrl.T.Helper()
	varargs := []any{ctx, input}
	for _, a := range options {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Run", varargs...)
	ret0, _ := ret[0].(*adk.AsyncIterator[*adk.AgentEvent])
	return ret0
}

// Run indicates an expected call of Run.
func (mr *MockAgentMockRecorder) Run(ctx, input any, options ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, input}, options...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockAgent)(nil).Run), varargs...)
}

// MockOnSubAgents is a mock of OnSubAgents interface.
type MockOnSubAgents struct {
	ctrl     *gomock.Controller
	recorder *MockOnSubAgentsMockRecorder
	isgomock struct{}
}

// MockOnSubAgentsMockRecorder is the mock recorder for MockOnSubAgents.
type MockOnSubAgentsMockRecorder struct {
	mock *MockOnSubAgents
}

// NewMockOnSubAgents creates a new mock instance.
func NewMockOnSubAgents(ctrl *gomock.Controller) *MockOnSubAgents {
	mock := &MockOnSubAgents{ctrl: ctrl}
	mock.recorder = &MockOnSubAgentsMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockOnSubAgents) EXPECT() *MockOnSubAgentsMockRecorder {
	return m.recorder
}

// OnDisallowTransferToParent mocks base method.
func (m *MockOnSubAgents) OnDisallowTransferToParent(ctx context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "OnDisallowTransferToParent", ctx)
	ret0, _ := ret[0].(error)
	return ret0
}

// OnDisallowTransferToParent indicates an expected call of OnDisallowTransferToParent.
func (mr *MockOnSubAgentsMockRecorder) OnDisallowTransferToParent(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDisallowTransferToParent", reflect.TypeOf((*MockOnSubAgents)(nil).OnDisallowTransferToParent), ctx)
}

// OnSetAsSubAgent mocks base method.
func (m *MockOnSubAgents) OnSetAsSubAgent(ctx context.Context, parent adk.Agent) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "OnSetAsSubAgent", ctx, parent)
	ret0, _ := ret[0].(error)
	return ret0
}

// OnSetAsSubAgent indicates an expected call of OnSetAsSubAgent.
func (mr *MockOnSubAgentsMockRecorder) OnSetAsSubAgent(ctx, parent any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnSetAsSubAgent", reflect.TypeOf((*MockOnSubAgents)(nil).OnSetAsSubAgent), ctx, parent)
}

// OnSetSubAgents mocks base method.
func (m *MockOnSubAgents) OnSetSubAgents(ctx context.Context, subAgents []adk.Agent) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "OnSetSubAgents", ctx, subAgents)
	ret0, _ := ret[0].(error)
	return ret0
}

// OnSetSubAgents indicates an expected call of OnSetSubAgents.
func (mr *MockOnSubAgentsMockRecorder) OnSetSubAgents(ctx, subAgents any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnSetSubAgents", reflect.TypeOf((*MockOnSubAgents)(nil).OnSetSubAgents), ctx, subAgents)
}

================================================
FILE: internal/mock/components/document/document_mock.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Code generated by MockGen. DO NOT EDIT.
// Source: interface.go

// Package document is a generated GoMock package.
package document

import (
	context "context"
	reflect "reflect"

	document "github.com/cloudwego/eino/components/document"
	schema "github.com/cloudwego/eino/schema"
	gomock "go.uber.org/mock/gomock"
)

// MockLoader is a mock of Loader interface.
type MockLoader struct {
	ctrl     *gomock.Controller
	recorder *MockLoaderMockRecorder
}

// MockLoaderMockRecorder is the mock recorder for MockLoader.
type MockLoaderMockRecorder struct {
	mock *MockLoader
}

// NewMockLoader creates a new mock instance.
func NewMockLoader(ctrl *gomock.Controller) *MockLoader {
	mock := &MockLoader{ctrl: ctrl}
	mock.recorder = &MockLoaderMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLoader) EXPECT() *MockLoaderMockRecorder {
	return m.recorder
}

// Load mocks base method.
func (m *MockLoader) Load(ctx context.Context, src document.Source, opts ...document.LoaderOption) ([]*schema.Document, error) {
	m.ctrl.T.Helper()
	varargs := []interface{}{ctx, src}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Load", varargs...)
	ret0, _ := ret[0].([]*schema.Document)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Load indicates an expected call of Load.
func (mr *MockLoaderMockRecorder) Load(ctx, src interface{}, opts ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]interface{}{ctx, src}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockLoader)(nil).Load), varargs...)
}

// MockTransformer is a mock of Transformer interface.
type MockTransformer struct {
	ctrl     *gomock.Controller
	recorder *MockTransformerMockRecorder
}

// MockTransformerMockRecorder is the mock recorder for MockTransformer.
type MockTransformerMockRecorder struct {
	mock *MockTransformer
}

// NewMockTransformer creates a new mock instance.
func NewMockTransformer(ctrl *gomock.Controller) *MockTransformer {
	mock := &MockTransformer{ctrl: ctrl}
	mock.recorder = &MockTransformerMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTransformer) EXPECT() *MockTransformerMockRecorder {
	return m.recorder
}

// Transform mocks base method.
func (m *MockTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
	m.ctrl.T.Helper()
	varargs := []interface{}{ctx, src}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Transform", varargs...)
	ret0, _ := ret[0].([]*schema.Document)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Transform indicates an expected call of Transform.
func (mr *MockTransformerMockRecorder) Transform(ctx, src interface{}, opts ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]interface{}{ctx, src}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transform", reflect.TypeOf((*MockTransformer)(nil).Transform), varargs...)
}

================================================
FILE: internal/mock/components/embedding/Embedding_mock.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Code generated by MockGen. DO NOT EDIT.
// Source: interface.go
//
// Generated by this command:
//
//	mockgen -destination ../../internal/mock/components/embedding/Embedding_mock.go --package embedding -source interface.go
//

// Package embedding is a generated GoMock package.
package embedding

import (
	context "context"
	reflect "reflect"

	embedding "github.com/cloudwego/eino/components/embedding"
	gomock "go.uber.org/mock/gomock"
)

// MockEmbedder is a mock of Embedder interface.
type MockEmbedder struct {
	ctrl     *gomock.Controller
	recorder *MockEmbedderMockRecorder
}

// MockEmbedderMockRecorder is the mock recorder for MockEmbedder.
type MockEmbedderMockRecorder struct {
	mock *MockEmbedder
}

// NewMockEmbedder creates a new mock instance.
func NewMockEmbedder(ctrl *gomock.Controller) *MockEmbedder {
	mock := &MockEmbedder{ctrl: ctrl}
	mock.recorder = &MockEmbedderMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEmbedder) EXPECT() *MockEmbedderMockRecorder {
	return m.recorder
}

// EmbedStrings mocks base method.
func (m *MockEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, texts}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "EmbedStrings", varargs...)
	ret0, _ := ret[0].([][]float64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// EmbedStrings indicates an expected call of EmbedStrings.
func (mr *MockEmbedderMockRecorder) EmbedStrings(ctx, texts any, opts ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, texts}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmbedStrings", reflect.TypeOf((*MockEmbedder)(nil).EmbedStrings), varargs...)
}

================================================
FILE: internal/mock/components/indexer/indexer_mock.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Code generated by MockGen. DO NOT EDIT.
// Source: interface.go
//
// Generated by this command:
//
//	mockgen -destination ../../internal/mock/components/indexer/indexer_mock.go --package indexer -source interface.go
//

// Package indexer is a generated GoMock package.
package indexer

import (
	context "context"
	reflect "reflect"

	indexer "github.com/cloudwego/eino/components/indexer"
	schema "github.com/cloudwego/eino/schema"
	gomock "go.uber.org/mock/gomock"
)

// MockIndexer is a mock of Indexer interface.
type MockIndexer struct {
	ctrl     *gomock.Controller
	recorder *MockIndexerMockRecorder
}

// MockIndexerMockRecorder is the mock recorder for MockIndexer.
type MockIndexerMockRecorder struct {
	mock *MockIndexer
}

// NewMockIndexer creates a new mock instance.
func NewMockIndexer(ctrl *gomock.Controller) *MockIndexer {
	mock := &MockIndexer{ctrl: ctrl}
	mock.recorder = &MockIndexerMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockIndexer) EXPECT() *MockIndexerMockRecorder {
	return m.recorder
}

// Store mocks base method.
func (m *MockIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, docs}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Store", varargs...)
	ret0, _ := ret[0].([]string)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Store indicates an expected call of Store.
func (mr *MockIndexerMockRecorder) Store(ctx, docs any, opts ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, docs}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockIndexer)(nil).Store), varargs...)
}

================================================
FILE: internal/mock/components/model/ChatModel_mock.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Code generated by MockGen. DO NOT EDIT.
// Source: interface.go
//
// Generated by this command:
//
//	mockgen -destination ../../internal/mock/components/model/ChatModel_mock.go --package model -source interface.go
//

// Package model is a generated GoMock package.
package model

import (
	context "context"
	reflect "reflect"

	model "github.com/cloudwego/eino/components/model"
	schema "github.com/cloudwego/eino/schema"
	gomock "go.uber.org/mock/gomock"
)

// MockBaseChatModel is a mock of BaseChatModel interface.
type MockBaseChatModel struct {
	ctrl     *gomock.Controller
	recorder *MockBaseChatModelMockRecorder
	isgomock struct{}
}

// MockBaseChatModelMockRecorder is the mock recorder for MockBaseChatModel.
type MockBaseChatModelMockRecorder struct {
	mock *MockBaseChatModel
}

// NewMockBaseChatModel creates a new mock instance.
func NewMockBaseChatModel(ctrl *gomock.Controller) *MockBaseChatModel {
	mock := &MockBaseChatModel{ctrl: ctrl}
	mock.recorder = &MockBaseChatModelMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockBaseChatModel) EXPECT() *MockBaseChatModelMockRecorder {
	return m.recorder
}

// Generate mocks base method.
func (m *MockBaseChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, input}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Generate", varargs...)
	ret0, _ := ret[0].(*schema.Message)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Generate indicates an expected call of Generate.
func (mr *MockBaseChatModelMockRecorder) Generate(ctx, input any, opts ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, input}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Generate", reflect.TypeOf((*MockBaseChatModel)(nil).Generate), varargs...)
}

// Stream mocks base method.
func (m *MockBaseChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, input}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Stream", varargs...)
	ret0, _ := ret[0].(*schema.StreamReader[*schema.Message])
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Stream indicates an expected call of Stream.
func (mr *MockBaseChatModelMockRecorder) Stream(ctx, input any, opts ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, input}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stream", reflect.TypeOf((*MockBaseChatModel)(nil).Stream), varargs...)
}

// MockChatModel is a mock of ChatModel interface.
type MockChatModel struct {
	ctrl     *gomock.Controller
	recorder *MockChatModelMockRecorder
	isgomock struct{}
}

// MockChatModelMockRecorder is the mock recorder for MockChatModel.
type MockChatModelMockRecorder struct {
	mock *MockChatModel
}

// NewMockChatModel creates a new mock instance.
func NewMockChatModel(ctrl *gomock.Controller) *MockChatModel {
	mock := &MockChatModel{ctrl: ctrl}
	mock.recorder = &MockChatModelMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockChatModel) EXPECT() *MockChatModelMockRecorder {
	return m.recorder
}

// BindTools mocks base method.
func (m *MockChatModel) BindTools(tools []*schema.ToolInfo) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "BindTools", tools)
	ret0, _ := ret[0].(error)
	return ret0
}

// BindTools indicates an expected call of BindTools.
func (mr *MockChatModelMockRecorder) BindTools(tools any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindTools", reflect.TypeOf((*MockChatModel)(nil).BindTools), tools)
}

// Generate mocks base method.
func (m *MockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, input}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Generate", varargs...)
	ret0, _ := ret[0].(*schema.Message)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Generate indicates an expected call of Generate.
func (mr *MockChatModelMockRecorder) Generate(ctx, input any, opts ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, input}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Generate", reflect.TypeOf((*MockChatModel)(nil).Generate), varargs...)
}

// Stream mocks base method.
func (m *MockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, input}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Stream", varargs...)
	ret0, _ := ret[0].(*schema.StreamReader[*schema.Message])
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Stream indicates an expected call of Stream.
func (mr *MockChatModelMockRecorder) Stream(ctx, input any, opts ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, input}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stream", reflect.TypeOf((*MockChatModel)(nil).Stream), varargs...)
}

// MockToolCallingChatModel is a mock of ToolCallingChatModel interface.
type MockToolCallingChatModel struct {
	ctrl     *gomock.Controller
	recorder *MockToolCallingChatModelMockRecorder
	isgomock struct{}
}

// MockToolCallingChatModelMockRecorder is the mock recorder for MockToolCallingChatModel.
type MockToolCallingChatModelMockRecorder struct {
	mock *MockToolCallingChatModel
}

// NewMockToolCallingChatModel creates a new mock instance.
func NewMockToolCallingChatModel(ctrl *gomock.Controller) *MockToolCallingChatModel {
	mock := &MockToolCallingChatModel{ctrl: ctrl}
	mock.recorder = &MockToolCallingChatModelMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockToolCallingChatModel) EXPECT() *MockToolCallingChatModelMockRecorder {
	return m.recorder
}

// Generate mocks base method.
func (m *MockToolCallingChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, input}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Generate", varargs...)
	ret0, _ := ret[0].(*schema.Message)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Generate indicates an expected call of Generate.
func (mr *MockToolCallingChatModelMockRecorder) Generate(ctx, input any, opts ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, input}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Generate", reflect.TypeOf((*MockToolCallingChatModel)(nil).Generate), varargs...)
}

// Stream mocks base method.
func (m *MockToolCallingChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, input}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Stream", varargs...)
	ret0, _ := ret[0].(*schema.StreamReader[*schema.Message])
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Stream indicates an expected call of Stream.
func (mr *MockToolCallingChatModelMockRecorder) Stream(ctx, input any, opts ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{ctx, input}, opts...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stream", reflect.TypeOf((*MockToolCallingChatModel)(nil).Stream), varargs...)
}

// WithTools mocks base method.
func (m *MockToolCallingChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "WithTools", tools)
	ret0, _ := ret[0].(model.ToolCallingChatModel)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// WithTools indicates an expected call of WithTools.
func (mr *MockToolCallingChatModelMockRecorder) WithTools(tools any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTools", reflect.TypeOf((*MockToolCallingChatModel)(nil).WithTools), tools)
}

================================================
FILE: internal/mock/components/retriever/retriever_mock.go
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Code generated by MockGen. DO NOT EDIT.
// Source: interface.go
//
// Generated by this command:
//
//	mockgen -destination ../../internal/mock/components/retriever/retriever_mock.go --package retriever -source interface.go
//

// Package retriever is a generated GoMock package.
package retriever

import (
	context "context"
	reflect "reflect"

	retriever "github.com/cloudwego/eino/components/retriever"
	schema "github.com/cloudwego/eino/schema"
	gomock "go.uber.org/mock/gomock"
)

// MockRetriever is a mock of Retriever interface.
type MockRetriever struct {
	ctrl     *gomock.Controller
	recorder *MockRetrieverMockRecorder
}

// MockRetrieverMockRecorder is the mock recorder for MockRetriever.
type MockRetrieverMockRecorder struct {
	mock *MockRetriever
}

// NewMockRetriever creates a new mock instance.
func NewMockRetriever(ctrl *gomock.Controller) *MockRetriever {
	mock := &MockRetriever{ctrl: ctrl}
	mock.recorder = &MockRetrieverMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRetriever) EXPECT() *MockRetrieverMockRecorder {
	return m.recorder
}

// Retrieve mocks base method.
func (m *MockRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	m.ctrl.T.Helper()
	varargs := []any{ctx, query}
	for _, a := range opts {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Retrieve", varargs...)
	ret0, _ := ret[0].([]*schema.Document)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Retrieve indicates an expected call of Retrieve.
func (mr *MockRetrieverMockRecorder) Retrieve(ctx, query any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, query}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retrieve", reflect.TypeOf((*MockRetriever)(nil).Retrieve), varargs...) } ================================================ FILE: internal/mock/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package mock provides mock implementations for testing purposes. // // This package aims to provide mock implementations for interfaces in the components package, // making it easier to use in testing environments. It includes mock implementations for // various core components such as retrievers, tools, message handlers, and graph runners. 
// // Directory Structure: // - components/: Contains mock implementations for various components // - retriever/: Provides mock implementation for the Retriever interface // - retriever_mock.go: Mock implementation for document retrieval // - tool/: Mock implementations for tool-related interfaces // - message/: Mock implementations for message handling components // - graph/: Mock implementations for graph execution components // - stream/: Mock implementations for streaming components // // Usage: // These mock implementations are primarily used in unit tests and integration tests, // allowing developers to conduct tests without depending on actual external services. // Each mock component strictly follows the contract of its corresponding interface // while providing controllable behaviors and results. // // Examples: // // - Using mock retriever: // retriever := mock.NewMockRetriever() // // Configure retriever behavior // // - Using mock tool: // tool := mock.NewMockTool() // // Configure tool behavior // // - Using mock graph runner: // runner := mock.NewMockGraphRunner() // // Configure runner behavior package mock ================================================ FILE: internal/safe/panic.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package safe import ( "fmt" ) type panicErr struct { info any stack []byte } func (p *panicErr) Error() string { return fmt.Sprintf("panic error: %v, \nstack: %s", p.info, string(p.stack)) } // NewPanicErr creates a new panic error. // panicErr is a wrapper of panic info and stack trace. // it implements the error interface, can print error message of info and stack trace. func NewPanicErr(info any, stack []byte) error { return &panicErr{ info: info, stack: stack, } } ================================================ FILE: internal/safe/panic_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package safe import ( "testing" "github.com/stretchr/testify/assert" ) func TestPanicErr(t *testing.T) { err := NewPanicErr("info", []byte("stack")) assert.Equal(t, "panic error: info, \nstack: stack", err.Error()) } ================================================ FILE: internal/serialization/serialization.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package serialization import ( "encoding/json" "fmt" "reflect" "github.com/bytedance/sonic" ) var m = map[string]reflect.Type{} var rm = map[reflect.Type]string{} func init() { _ = GenericRegister[int]("_eino_int") _ = GenericRegister[int8]("_eino_int8") _ = GenericRegister[int16]("_eino_int16") _ = GenericRegister[int32]("_eino_int32") _ = GenericRegister[int64]("_eino_int64") _ = GenericRegister[uint]("_eino_uint") _ = GenericRegister[uint8]("_eino_uint8") _ = GenericRegister[uint16]("_eino_uint16") _ = GenericRegister[uint32]("_eino_uint32") _ = GenericRegister[uint64]("_eino_uint64") _ = GenericRegister[float32]("_eino_float32") _ = GenericRegister[float64]("_eino_float64") _ = GenericRegister[complex64]("_eino_complex64") _ = GenericRegister[complex128]("_eino_complex128") _ = GenericRegister[uintptr]("_eino_uintptr") _ = GenericRegister[bool]("_eino_bool") _ = GenericRegister[string]("_eino_string") _ = GenericRegister[any]("_eino_any") } func GenericRegister[T any](key string) error { t := reflect.TypeOf((*T)(nil)).Elem() for t.Kind() == reflect.Ptr { t = t.Elem() } if nt, ok := m[key]; ok { return fmt.Errorf("key[%s] already registered to %s", key, nt.String()) } if nk, ok := rm[t]; ok { return fmt.Errorf("type[%s] already registered to %s", t.String(), nk) } m[key] = t rm[t] = key return nil } type InternalSerializer struct{} func (i *InternalSerializer) Marshal(v any) ([]byte, error) { is, err := internalMarshal(v, nil) if err != nil { return nil, err } return sonic.Marshal(is) } func (i *InternalSerializer) Unmarshal(data []byte, v any) error { val, 
err := unmarshal(data, reflect.TypeOf(v)) if err != nil { return fmt.Errorf("failed to unmarshal: %w", err) } rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("failed to unmarshal: value must be a non-nil pointer") } target := rv.Elem() if !target.CanSet() { return fmt.Errorf("failed to unmarshal: output value must be settable") } if val == nil { target.Set(reflect.Zero(target.Type())) return nil } source := reflect.ValueOf(val) var set func(target, source reflect.Value) bool set = func(target, source reflect.Value) bool { if !source.IsValid() { target.Set(reflect.Zero(target.Type())) return true } if source.Type().AssignableTo(target.Type()) { target.Set(source) return true } if target.Kind() == reflect.Ptr { if target.IsNil() { if !target.CanSet() { return false } target.Set(reflect.New(target.Type().Elem())) } return set(target.Elem(), source) } if source.Kind() == reflect.Ptr { if source.IsNil() { target.Set(reflect.Zero(target.Type())) return true } return set(target, source.Elem()) } if source.Type().ConvertibleTo(target.Type()) { target.Set(source.Convert(target.Type())) return true } return false } if set(target, source) { return nil } return fmt.Errorf("failed to unmarshal: cannot assign %s to %s", reflect.TypeOf(val), target.Type()) } func unmarshal(data []byte, t reflect.Type) (any, error) { is := &internalStruct{} err := sonic.Unmarshal(data, is) if err != nil { return nil, err } return internalUnmarshal(is, t) } type internalStruct struct { Type *valueType `json:",omitempty"` JSONValue json.RawMessage `json:",omitempty"` // map or struct // in map, the key is the serialized map key anyway todo: if key is string, don't serialize // in struct, the key is the original field name MapValues map[string]*internalStruct `json:",omitempty"` // slice SliceValues []*internalStruct `json:",omitempty"` } type valueType struct { PointerNum uint32 `json:",omitempty"` SimpleType string `json:",omitempty"` StructType string 
`json:",omitempty"` MapKeyType *valueType `json:",omitempty"` MapValueType *valueType `json:",omitempty"` SliceValueType *valueType `json:",omitempty"` } func extractType(t reflect.Type) (*valueType, error) { ret := &valueType{} for t.Kind() == reflect.Ptr { ret.PointerNum += 1 t = t.Elem() } var err error if t.Kind() == reflect.Map { ret.MapKeyType, err = extractType(t.Key()) if err != nil { return nil, err } ret.MapValueType, err = extractType(t.Elem()) if err != nil { return nil, err } } else if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { ret.SliceValueType, err = extractType(t.Elem()) if err != nil { return nil, err } } else { key, ok := rm[t] if !ok { return ret, fmt.Errorf("unknown type: %s", t.String()) } ret.SimpleType = key } return ret, nil } func restoreType(vt *valueType) (reflect.Type, error) { if vt.SimpleType != "" { rt, ok := m[vt.SimpleType] if !ok { return nil, fmt.Errorf("unknown type: %s", vt.SimpleType) } return resolvePointerNum(vt.PointerNum, rt), nil } if vt.StructType != "" { rt, ok := m[vt.StructType] if !ok { return nil, fmt.Errorf("unknown type: %s", vt.StructType) } return resolvePointerNum(vt.PointerNum, rt), nil } if vt.MapKeyType != nil { rkt, err := restoreType(vt.MapKeyType) if err != nil { return nil, err } rvt, err := restoreType(vt.MapValueType) if err != nil { return nil, err } return resolvePointerNum(vt.PointerNum, reflect.MapOf(rkt, rvt)), nil } if vt.SliceValueType != nil { rt, err := restoreType(vt.SliceValueType) if err != nil { return nil, err } return resolvePointerNum(vt.PointerNum, reflect.SliceOf(rt)), nil } return nil, fmt.Errorf("empty value") } func internalMarshal(v any, fieldType reflect.Type) (*internalStruct, error) { if v == nil || (reflect.ValueOf(v).IsZero() && fieldType != nil && fieldType.Kind() != reflect.Interface) { return nil, nil } ret := &internalStruct{} rv := reflect.ValueOf(v) rt := rv.Type() typeUnspecific := fieldType == nil || fieldType.Kind() == reflect.Interface var pointerNum 
uint32 for rt.Kind() == reflect.Ptr { pointerNum++ if !rv.IsNil() { rv = rv.Elem() rt = rt.Elem() continue } for rt.Kind() == reflect.Ptr { rt = rt.Elem() } if typeUnspecific { // need type registered key, ok := rm[rt] if !ok { return nil, fmt.Errorf("unknown type: %v", rt) } ret.Type = &valueType{ PointerNum: pointerNum, SimpleType: key, } } ret.JSONValue = json.RawMessage("null") return ret, nil } switch rt.Kind() { case reflect.Struct: if typeUnspecific { // need type registered key, ok := rm[rt] if !ok { return nil, fmt.Errorf("unknown type: %v", rt) } if checkMarshaler(rt) { ret.Type = &valueType{ PointerNum: pointerNum, SimpleType: key, } } else { ret.Type = &valueType{ PointerNum: pointerNum, StructType: key, } } } if checkMarshaler(rt) { jsonBytes, err := json.Marshal(rv.Interface()) if err != nil { return nil, err } ret.JSONValue = jsonBytes return ret, nil } ret.MapValues = make(map[string]*internalStruct) for i := 0; i < rt.NumField(); i++ { field := rt.Field(i) // only handle exported fields if field.PkgPath == "" { k := field.Name v := rv.Field(i) internalValue, err := internalMarshal(v.Interface(), field.Type) if err != nil { return nil, err } ret.MapValues[k] = internalValue } } return ret, nil case reflect.Map: if typeUnspecific { var err error ret.Type = &valueType{ PointerNum: pointerNum, } // map key type ret.Type.MapKeyType, err = extractType(rt.Key()) if err != nil { return nil, err } // map value type ret.Type.MapValueType, err = extractType(rt.Elem()) if err != nil { return nil, err } } ret.MapValues = make(map[string]*internalStruct) iter := rv.MapRange() for iter.Next() { k := iter.Key() v := iter.Value() internalValue, err := internalMarshal(v.Interface(), rt.Elem()) if err != nil { return nil, err } keyStr, err := sonic.MarshalString(k.Interface()) if err != nil { return nil, fmt.Errorf("marshaling map key[%v] fail: %v", k.Interface(), err) } ret.MapValues[keyStr] = internalValue } return ret, nil case reflect.Slice, reflect.Array: if 
typeUnspecific { var err error ret.Type = &valueType{PointerNum: pointerNum} ret.Type.SliceValueType, err = extractType(rt.Elem()) if err != nil { return nil, err } } length := rv.Len() ret.SliceValues = make([]*internalStruct, length) for i := 0; i < length; i++ { internalValue, err := internalMarshal(rv.Index(i).Interface(), rt.Elem()) if err != nil { return nil, err } ret.SliceValues[i] = internalValue } return ret, nil default: if typeUnspecific { key, ok := rm[rv.Type()] if !ok { return nil, fmt.Errorf("unknown type: %v", rt) } ret.Type = &valueType{ PointerNum: pointerNum, SimpleType: key, } } jsonBytes, err := json.Marshal(rv.Interface()) if err != nil { return nil, err } ret.JSONValue = jsonBytes return ret, nil } } func internalUnmarshal(v *internalStruct, typ reflect.Type) (any, error) { if v == nil { return nil, nil } if v.Type == nil { // specific type if checkMarshaler(typ) { pv := reflect.New(typ) err := json.Unmarshal(v.JSONValue, pv.Interface()) if err != nil { return nil, err } return pv.Elem().Interface(), nil } return internalSpecificTypeUnmarshal(v, typ) } if len(v.Type.SimpleType) != 0 { // based type t, ok := m[v.Type.SimpleType] if !ok { return nil, fmt.Errorf("unknown type key: %v", v.Type) } pResult := reflect.New(resolvePointerNum(v.Type.PointerNum, t)) err := sonic.Unmarshal(v.JSONValue, pResult.Interface()) if err != nil { return nil, fmt.Errorf("unmarshal type[%s] fail: %v, data: %s", t.String(), err, string(v.JSONValue)) } return pResult.Elem().Interface(), nil } if len(v.Type.StructType) > 0 { // struct rt, ok := m[v.Type.StructType] if !ok { return nil, fmt.Errorf("unknown type key: %v", v.Type.StructType) } result, dResult := createValueFromType(resolvePointerNum(v.Type.PointerNum, rt)) err := setStructFields(dResult, v.MapValues) if err != nil { return nil, err } return result.Interface(), nil } if v.Type.MapKeyType != nil { // map rkt, err := restoreType(v.Type.MapKeyType) if err != nil { return nil, err } rvt, err := 
restoreType(v.Type.MapValueType) if err != nil { return nil, err } result, dResult := createValueFromType(reflect.MapOf(rkt, rvt)) err = setMapKVs(dResult, v.MapValues) if err != nil { return nil, err } return result.Interface(), nil } // slice rvt, err := restoreType(v.Type.SliceValueType) if err != nil { return nil, err } result, dResult := createValueFromType(reflect.SliceOf(rvt)) err = setSliceElems(dResult, v.SliceValues) if err != nil { return nil, err } return result.Interface(), nil } func internalSpecificTypeUnmarshal(is *internalStruct, typ reflect.Type) (any, error) { _, dtyp := derefPointerNum(typ) result, dResult := createValueFromType(typ) if dtyp.Kind() == reflect.Struct { err := setStructFields(dResult, is.MapValues) if err != nil { return nil, err } return result.Interface(), nil } else if dtyp.Kind() == reflect.Map { err := setMapKVs(dResult, is.MapValues) if err != nil { return nil, err } return result.Interface(), nil } else if dtyp.Kind() == reflect.Array || dtyp.Kind() == reflect.Slice { err := setSliceElems(dResult, is.SliceValues) if err != nil { return nil, err } return result.Interface(), nil } // simple type v := reflect.New(typ) err := sonic.Unmarshal(is.JSONValue, v.Interface()) if err != nil { return nil, fmt.Errorf("unmarshal type[%s] fail: %v", typ.String(), err) } return v.Elem().Interface(), nil } func setSliceElems(dResult reflect.Value, values []*internalStruct) error { t := dResult.Type() // Handle arrays differently from slices // Arrays have fixed size and cannot use reflect.Append if dResult.Kind() == reflect.Array { for i, internalValue := range values { if i >= dResult.Len() { return fmt.Errorf("array index out of bounds: trying to set index %d in array of length %d", i, dResult.Len()) } value, err := internalUnmarshal(internalValue, t.Elem()) if err != nil { return fmt.Errorf("unmarshal array[%s] element %d fail: %v", t.Elem(), i, err) } if value == nil { dResult.Index(i).Set(reflect.Zero(t.Elem())) } else { 
dResult.Index(i).Set(reflect.ValueOf(value)) } } return nil } // For slices, use Append as before for _, internalValue := range values { value, err := internalUnmarshal(internalValue, t.Elem()) if err != nil { return fmt.Errorf("unmarshal slice[%s] fail: %v", t.Elem(), err) } if value == nil { // empty value dResult.Set(reflect.Append(dResult, reflect.New(t.Elem()).Elem())) } else { dResult.Set(reflect.Append(dResult, reflect.ValueOf(value))) } } return nil } func setMapKVs(dResult reflect.Value, values map[string]*internalStruct) error { t := dResult.Type() for marshaledMapKey, internalValue := range values { prkv := reflect.New(t.Key()) err := sonic.UnmarshalString(marshaledMapKey, prkv.Interface()) if err != nil { return fmt.Errorf("unmarshal map key[%v] to type[%s] fail: %v", marshaledMapKey, t.Key(), err) } value, err := internalUnmarshal(internalValue, t.Elem()) if err != nil { return fmt.Errorf("unmarshal map value fail: %v", err) } if value == nil { dResult.SetMapIndex(prkv.Elem(), reflect.New(t.Elem()).Elem()) } else { dResult.SetMapIndex(prkv.Elem(), reflect.ValueOf(value)) } } return nil } func setStructFields(dResult reflect.Value, values map[string]*internalStruct) error { t := dResult.Type() for k, internalValue := range values { sf, ok := t.FieldByName(k) if !ok { continue } value, err := internalUnmarshal(internalValue, sf.Type) if err != nil { return fmt.Errorf("unmarshal map field[%v] fail: %v", k, err) } err = setStructField(t, dResult, k, value) if err != nil { return err } } return nil } func setStructField(t reflect.Type, s reflect.Value, fieldName string, val any) error { field := s.FieldByName(fieldName) if !field.CanSet() { return fmt.Errorf("unmarshal map fail, can not set field %v", fieldName) } if val == nil { rft, ok := t.FieldByName(fieldName) if !ok { return fmt.Errorf("unmarshal map fail, cannot find field: %v", fieldName) } field.Set(reflect.New(rft.Type).Elem()) } else { field.Set(reflect.ValueOf(val)) } return nil } func 
resolvePointerNum(pointerNum uint32, t reflect.Type) reflect.Type { for i := uint32(0); i < pointerNum; i++ { t = reflect.PointerTo(t) } return t } func derefPointerNum(t reflect.Type) (uint32, reflect.Type) { var ptrCount uint32 = 0 for t != nil && t.Kind() == reflect.Ptr { t = t.Elem() ptrCount++ } return ptrCount, t } func createValueFromType(t reflect.Type) (value reflect.Value, derefValue reflect.Value) { value = reflect.New(t).Elem() derefValue = value for derefValue.Kind() == reflect.Ptr { if derefValue.IsNil() { derefValue.Set(reflect.New(derefValue.Type().Elem())) } derefValue = derefValue.Elem() } if derefValue.Kind() == reflect.Map && derefValue.IsNil() { derefValue.Set(reflect.MakeMap(derefValue.Type())) } // Use Len() == 0 instead of IsNil() for slices to avoid panic // IsNil() can panic on uninitialized slice values created via reflect.New().Elem() if derefValue.Kind() == reflect.Slice { if derefValue.Len() == 0 && derefValue.Cap() == 0 { derefValue.Set(reflect.MakeSlice(derefValue.Type(), 0, 0)) } } // Arrays cannot be nil and don't need initialization return value, derefValue } var marshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() var unmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() func checkMarshaler(t reflect.Type) bool { for t.Kind() == reflect.Ptr { t = t.Elem() } if (t.Implements(marshalerType) || reflect.PointerTo(t).Implements(marshalerType)) && (t.Implements(unmarshalerType) || reflect.PointerTo(t).Implements(unmarshalerType)) { return true } return false } ================================================ FILE: internal/serialization/serialization_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package serialization import ( "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type myInterface interface { Method() } type myStruct struct { A string } func (m *myStruct) Method() {} type myStruct2 struct { A any B myInterface C map[string]**myStruct D map[myStruct]any E []any f string G myStruct3 H *myStruct4 I []*myStruct3 J map[string]myStruct3 K myStruct4 L []*myStruct4 M map[string]myStruct4 } type myStruct3 struct { FieldA string } type myStruct4 struct { FieldA string } func (m *myStruct4) UnmarshalJSON(bytes []byte) error { m.FieldA = string(bytes) return nil } func (m myStruct4) MarshalJSON() ([]byte, error) { return []byte(m.FieldA), nil } func TestSerialization(t *testing.T) { _ = GenericRegister[myStruct]("myStruct") _ = GenericRegister[myStruct2]("myStruct2") _ = GenericRegister[myInterface]("myInterface") ms := myStruct{A: "test"} pms := &ms pointerOfPointerOfMyStruct := &pms ms1 := myStruct{A: "1"} ms2 := myStruct{A: "2"} ms3 := myStruct{A: "3"} ms4 := myStruct{A: "4"} values := []any{ 10, "test", ms, pms, pointerOfPointerOfMyStruct, myInterface(pms), []int{1, 2, 3}, []any{1, "test"}, []myInterface{nil, &myStruct{A: "1"}, &myStruct{A: "2"}}, map[string]string{"123": "123", "abc": "abc"}, map[string]myInterface{"1": nil, "2": pms}, map[string]any{"123": 1, "abc": &myStruct{A: "1"}, "bcd": nil}, map[myStruct]any{ ms1: 1, ms2: &myStruct{ A: "2", }, ms3: nil, ms4: []any{ 1, pointerOfPointerOfMyStruct, "123", &myStruct{ A: "1", }, nil, map[myStruct]any{ ms1: 1, ms2: nil, }, }, }, myStruct2{ A: "123", B: 
&myStruct{ A: "test", }, C: map[string]**myStruct{ "a": pointerOfPointerOfMyStruct, }, D: map[myStruct]any{{"a"}: 1}, E: []any{1, "2", 3}, f: "", G: myStruct3{ FieldA: "1", }, H: nil, I: []*myStruct3{ {FieldA: "2"}, {FieldA: "3"}, }, J: map[string]myStruct3{ "1": {FieldA: "4"}, "2": {FieldA: "5"}, }, K: myStruct4{ FieldA: "1", }, L: []*myStruct4{ {FieldA: "2"}, {FieldA: "3"}, }, M: map[string]myStruct4{ "1": {FieldA: "4"}, "2": {FieldA: "5"}, }, }, map[string]map[string][]map[string][][]string{ "1": { "a": []map[string][][]string{ {"b": { {"c"}, {"d"}, }}, }, }, }, []*myStruct{}, &myStruct{}, } for _, value := range values { data, err := (&InternalSerializer{}).Marshal(value) assert.NoError(t, err) v := reflect.New(reflect.TypeOf(value)).Interface() err = (&InternalSerializer{}).Unmarshal(data, v) assert.NoError(t, err) assert.Equal(t, value, reflect.ValueOf(v).Elem().Interface()) } } type myStruct5 struct { FieldA string } func (m *myStruct5) UnmarshalJSON(bytes []byte) error { m.FieldA = "FieldA" return nil } func (m myStruct5) MarshalJSON() ([]byte, error) { return []byte("1"), nil } func TestMarshalStruct(t *testing.T) { assert.NoError(t, GenericRegister[myStruct5]("myStruct5")) s := myStruct5{FieldA: "1"} data, err := (&InternalSerializer{}).Marshal(s) assert.NoError(t, err) result := &myStruct5{} err = (&InternalSerializer{}).Unmarshal(data, result) assert.NoError(t, err) assert.Equal(t, myStruct5{FieldA: "FieldA"}, *result) ma := map[string]any{ "1": s, } data, err = (&InternalSerializer{}).Marshal(ma) assert.NoError(t, err) result2 := map[string]any{} err = (&InternalSerializer{}).Unmarshal(data, &result2) assert.NoError(t, err) assert.Equal(t, map[string]any{ "1": myStruct5{FieldA: "FieldA"}, }, result2) } type unmarshalTestStruct struct { Foo string Bar int } func init() { // Register types for the serializer to work. // This is necessary for the serializer to know how to handle custom struct types. 
err := GenericRegister[unmarshalTestStruct]("unmarshalTestStruct") if err != nil { panic(err) } } func TestInternalSerializer_Unmarshal(t *testing.T) { s := InternalSerializer{} t.Run("success cases", func(t *testing.T) { // Helper to create a pointer to a value, needed for the expected value in one test case. ptr := func(i int) *int { return &i } testCases := []struct { name string inputValue any outputPtr any expectedVal any }{ { name: "simple type", inputValue: 123, outputPtr: new(int), expectedVal: 123, }, { name: "struct type", inputValue: unmarshalTestStruct{Foo: "hello", Bar: 42}, outputPtr: new(unmarshalTestStruct), expectedVal: unmarshalTestStruct{Foo: "hello", Bar: 42}, }, { name: "pointer to struct", inputValue: &unmarshalTestStruct{Foo: "world", Bar: 99}, outputPtr: new(*unmarshalTestStruct), expectedVal: &unmarshalTestStruct{Foo: "world", Bar: 99}, }, { name: "unmarshal pointer to value", inputValue: &unmarshalTestStruct{Foo: "p2v", Bar: 1}, outputPtr: new(unmarshalTestStruct), expectedVal: unmarshalTestStruct{Foo: "p2v", Bar: 1}, }, { name: "unmarshal value to pointer", inputValue: unmarshalTestStruct{Foo: "v2p", Bar: 2}, outputPtr: new(*unmarshalTestStruct), expectedVal: &unmarshalTestStruct{Foo: "v2p", Bar: 2}, }, { name: "unmarshal nil pointer", inputValue: (*unmarshalTestStruct)(nil), outputPtr: &struct{ v *unmarshalTestStruct }{v: &unmarshalTestStruct{}}, // placeholder to be replaced expectedVal: (*unmarshalTestStruct)(nil), }, { name: "convertible types", inputValue: int32(42), outputPtr: new(int64), expectedVal: int64(42), }, { name: "pointer to pointer destination", inputValue: 12345, outputPtr: new(*int), expectedVal: ptr(12345), }, { name: "unmarshal to any", inputValue: unmarshalTestStruct{Foo: "any", Bar: 101}, outputPtr: new(any), expectedVal: unmarshalTestStruct{Foo: "any", Bar: 101}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { data, err := s.Marshal(tc.inputValue) require.NoError(t, err) // Special handling 
for the nil test case to correctly pass the pointer. if tc.name == "unmarshal nil pointer" { target := tc.outputPtr.(*struct{ v *unmarshalTestStruct }) err = s.Unmarshal(data, &target.v) require.NoError(t, err) assert.Nil(t, target.v) return } err = s.Unmarshal(data, tc.outputPtr) require.NoError(t, err) // Dereference the pointer to get the actual value for comparison. actualVal := reflect.ValueOf(tc.outputPtr).Elem().Interface() assert.Equal(t, tc.expectedVal, actualVal) }) } }) t.Run("error cases", func(t *testing.T) { data, err := s.Marshal(123) require.NoError(t, err) t.Run("destination not a pointer", func(t *testing.T) { var output int err := s.Unmarshal(data, output) require.Error(t, err) assert.Contains(t, err.Error(), "value must be a non-nil pointer") }) t.Run("destination is a nil pointer", func(t *testing.T) { var output *int // nil err := s.Unmarshal(data, output) require.Error(t, err) assert.Contains(t, err.Error(), "value must be a non-nil pointer") }) t.Run("type mismatch", func(t *testing.T) { strData, mErr := s.Marshal("i am a string") require.NoError(t, mErr) var output int err := s.Unmarshal(strData, &output) require.Error(t, err) assert.Contains(t, err.Error(), "cannot assign") }) t.Run("unconvertible types", func(t *testing.T) { intData, mErr := s.Marshal(123) require.NoError(t, mErr) var output bool err := s.Unmarshal(intData, &output) require.Error(t, err) assert.Contains(t, err.Error(), "cannot assign") }) }) } ================================================ FILE: llms.txt ================================================ # Eino > Eino is a Go-based LLM application development framework by ByteDance. > It provides component abstractions, a graph/chain orchestration engine, > streaming primitives, a callback system, and an Agent Development Kit (ADK) > for building production-grade LLM applications. 
## Repositories - [eino](https://github.com/cloudwego/eino) — core framework (this repo) - [eino-ext](https://github.com/cloudwego/eino-ext) — component integrations (OpenAI, Ark, Ollama, Redis, S3, …) - [eino-examples](https://github.com/cloudwego/eino-examples) — runnable example applications ## Overview & Background - [Overview](https://www.cloudwego.io/docs/eino/overview/) - [ByteDance Eino Practice](https://www.cloudwego.io/docs/eino/overview/bytedance_eino_practice/) - [Eino Open Source](https://www.cloudwego.io/docs/eino/overview/eino_open_source/) - [Graph or Agent — when to use which](https://www.cloudwego.io/docs/eino/overview/graph_or_agent/) ## Quick Start - [Simple LLM Application](https://www.cloudwego.io/docs/eino/quick_start/simple_llm_application/) - [Agent with Tools](https://www.cloudwego.io/docs/eino/quick_start/agent_llm_with_tools/) - [Eino Cookbook](https://www.cloudwego.io/docs/eino/eino-cookbook/) ## Core Concepts — Components Components are the typed building blocks of eino pipelines. Each has a defined interface in the core repo; implementations live in eino-ext. 
- [Components overview](https://www.cloudwego.io/docs/eino/core_modules/components/) - [ChatModel](https://www.cloudwego.io/docs/eino/core_modules/components/chat_model_guide/) - [ChatTemplate](https://www.cloudwego.io/docs/eino/core_modules/components/chat_template_guide/) - [ToolsNode](https://www.cloudwego.io/docs/eino/core_modules/components/tools_node_guide/) - [How to create a Tool](https://www.cloudwego.io/docs/eino/core_modules/components/tools_node_guide/how_to_create_a_tool/) - [Retriever](https://www.cloudwego.io/docs/eino/core_modules/components/retriever_guide/) - [Indexer](https://www.cloudwego.io/docs/eino/core_modules/components/indexer_guide/) - [Embedding](https://www.cloudwego.io/docs/eino/core_modules/components/embedding_guide/) - [DocumentLoader](https://www.cloudwego.io/docs/eino/core_modules/components/document_loader_guide/) - [DocumentParser](https://www.cloudwego.io/docs/eino/core_modules/components/document_loader_guide/document_parser_interface_guide/) - [DocumentTransformer](https://www.cloudwego.io/docs/eino/core_modules/components/document_transformer_guide/) - [Lambda](https://www.cloudwego.io/docs/eino/core_modules/components/lambda_guide/) - [AgenticChatModel](https://www.cloudwego.io/docs/eino/core_modules/components/agentic_chat_model_guide/) - [AgenticChatTemplate](https://www.cloudwego.io/docs/eino/core_modules/components/agentic_chat_template_guide/) - [AgenticToolsNode](https://www.cloudwego.io/docs/eino/core_modules/components/agentic_tools_node_guide/) ## Core Concepts — Orchestration The orchestration layer composes components into executable pipelines. Chain is a linear sequence; Graph is a DAG with conditional edges; Workflow is a higher-level structured abstraction over Graph. 
- [Chain & Graph introduction](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/chain_graph_introduction/) - [Orchestration design principles](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/orchestration_design_principles/) - [Workflow orchestration framework](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/workflow_orchestration_framework/) - [Stream programming essentials](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/stream_programming_essentials/) - [Callback system](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/callback_manual/) - [CallOption capabilities](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/call_option_capabilities/) - [Checkpoint & interrupt/resume](https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/checkpoint_interrupt/) ## Core Concepts — Flow Integration - [ReAct agent](https://www.cloudwego.io/docs/eino/core_modules/flow_integration_components/react_agent_manual/) - [Multi-agent hosting](https://www.cloudwego.io/docs/eino/core_modules/flow_integration_components/multi_agent_hosting/) ## Agent Development Kit (ADK) The ADK provides a higher-level runtime for building, composing, and deploying agents. It sits above the Graph layer and introduces Agent, Skill, and Middleware abstractions. 
- [ADK overview](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/) - [Agent quickstart](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_quickstart/) - [Agent interface](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_interface/) - [Agent collaboration](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_collaboration/) - [Agent implementations](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_implementation/) - [ChatModel agent](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_implementation/chat_model/) - [Workflow agent](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_implementation/workflow/) - [Supervisor agent](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_implementation/supervisor/) - [Plan-and-execute agent](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_implementation/plan_execute/) - [Deep agents](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_implementation/deepagents/) - [Agent extension](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_extension/) - [Human-in-the-loop (HITL)](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/agent_hitl/) - [ADK callbacks](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/adk_agent_callback/) - [ChatModelAgent middleware](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/) - [Filesystem middleware](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_filesystem/) - [Skill middleware](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_skill/) - [Summarization middleware](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/) - [Plan-task 
middleware](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_plantask/) - [Tool-search middleware](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_toolsearch/) - [Tool-reduction middleware](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_toolreduction/) - [Patch-toolcalls middleware](https://www.cloudwego.io/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_patchtoolcalls/) ## DevOps Tooling - [IDE plugin guide](https://www.cloudwego.io/docs/eino/core_modules/devops/ide_plugin_guide/) - [Visual orchestration plugin](https://www.cloudwego.io/docs/eino/core_modules/devops/visual_orchestration_plugin_guide/) - [Visual debug plugin](https://www.cloudwego.io/docs/eino/core_modules/devops/visual_debug_plugin_guide/) ## Ecosystem Integrations (eino-ext) - [ChatModel integrations](https://www.cloudwego.io/docs/eino/ecosystem_integration/chat_model/) - [Document integrations](https://www.cloudwego.io/docs/eino/ecosystem_integration/document/) - [Embedding integrations](https://www.cloudwego.io/docs/eino/ecosystem_integration/embedding/) - [Tool integrations](https://www.cloudwego.io/docs/eino/ecosystem_integration/tool/) - [Callback integrations](https://www.cloudwego.io/docs/eino/ecosystem_integration/callbacks/) - [Indexer integrations](https://www.cloudwego.io/docs/eino/ecosystem_integration/indexer/) - [Retriever integrations](https://www.cloudwego.io/docs/eino/ecosystem_integration/retriever/) - [ChatTemplate integrations](https://www.cloudwego.io/docs/eino/ecosystem_integration/chat_template/) ## Release Notes & Migration - [v0.1](https://www.cloudwego.io/docs/eino/release_notes_and_migration/v01_first_release/) - [v0.2](https://www.cloudwego.io/docs/eino/release_notes_and_migration/v02_second_release/) - [v0.3 — breaking 
changes](https://www.cloudwego.io/docs/eino/release_notes_and_migration/v03_tiny_break_change/) - [v0.4 — compose optimization](https://www.cloudwego.io/docs/eino/release_notes_and_migration/eino_v0.4._-compose_optimization/) - [v0.5 — ADK implementation](https://www.cloudwego.io/docs/eino/release_notes_and_migration/eino_v0.5._-adk_implementation/) - [v0.6 — JSON schema optimization](https://www.cloudwego.io/docs/eino/release_notes_and_migration/eino_v0.6._-jsonschema_optimization/) - [v0.7 — interrupt/resume refactor](https://www.cloudwego.io/docs/eino/release_notes_and_migration/eino_v0.7._-interrupt_resume_refactor/) - [v0.8 — ADK middlewares](https://www.cloudwego.io/docs/eino/release_notes_and_migration/eino_v0.8._-adk_middlewares/) ## FAQ - [FAQ](https://www.cloudwego.io/docs/eino/faq/) ================================================ FILE: schema/doc.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package schema defines the core data structures and utilities shared across // all Eino components. // // # Key Types // // [Message] is the universal unit of communication between users, models, and // tools. It carries role, text content, multimodal media, tool calls, and // response metadata. Helper constructors — [UserMessage], [SystemMessage], // [AssistantMessage], [ToolMessage] — cover the most common cases. // // [Document] represents a piece of text with a metadata map. 
Typed accessors // (Score, SubIndexes, DenseVector, SparseVector, DSLInfo, ExtraInfo) read and // write well-known metadata keys so pipeline stages can pass structured data // without coupling to specific struct types. // // [ToolInfo] describes a tool's name, description, and parameter schema. // Parameters can be declared either as a [ParameterInfo] map (simple, // struct-like) or as a raw [jsonschema.Schema] (full JSON Schema 2020-12 expressiveness). // [ToolChoice] controls whether the model must, may, or must not call tools. // // # Streaming // // [StreamReader] and [StreamWriter] are the building blocks for streaming data // through Eino pipelines. Create a linked pair with [Pipe]: // // sr, sw := schema.Pipe[*schema.Message](10) // go func() { // defer sw.Close() // sw.Send(chunk, nil) // }() // defer sr.Close() // for { // chunk, err := sr.Recv() // if errors.Is(err, io.EOF) { break } // if err != nil { return err } // process(chunk) // } // // Important constraints: // - A StreamReader is read-once: only one goroutine may call Recv. // - Always call Close, even when the loop ends on io.EOF, to release resources. // - To give the same stream to multiple consumers, call [StreamReader.Copy]. // // # Four Streaming Paradigms // // Eino components and Lambda functions are classified by their input/output // streaming shape. The framework automatically bridges mismatches: // // - Invoke: non-streaming in, non-streaming out (ping-pong). // - Stream: non-streaming in, StreamReader out (server-streaming). ChatModel // and Tool support this. // - Collect: StreamReader in, non-streaming out (client-streaming). Useful // for branch conditions that decide after the first chunk. // - Transform: StreamReader in, StreamReader out (bidirectional). // // When an upstream node outputs T but a downstream node only accepts // StreamReader[T], the framework wraps T in a single-chunk StreamReader — // this is called a "fake stream". It satisfies the interface but does NOT // reduce time-to-first-chunk.
Conversely, when a downstream node only accepts // T but the upstream outputs StreamReader[T], the framework automatically // concatenates the stream into a complete T. // // Utility functions: // - [StreamReaderFromArray] wraps a slice as a stream (useful in tests). // - [MergeStreamReaders] fans-in multiple streams into one. // - [MergeNamedStreamReaders] like MergeStreamReaders but emits [SourceEOF] // when each named source ends, useful for tracking per-source completion. // - [StreamReaderWithConvert] transforms element types; return [ErrNoValue] // from the convert function to skip an element. // // See https://www.cloudwego.io/docs/eino/core_modules/chain_and_graph_orchestration/stream_programming_essentials/ package schema ================================================ FILE: schema/document.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package schema const ( docMetaDataKeySubIndexes = "_sub_indexes" docMetaDataKeyScore = "_score" docMetaDataKeyExtraInfo = "_extra_info" docMetaDataKeyDSL = "_dsl" docMetaDataKeyDenseVector = "_dense_vector" docMetaDataKeySparseVector = "_sparse_vector" ) // Document is a piece of text with a metadata map. It is the shared currency // between Loader, Transformer, Indexer, and Retriever components. // // Metadata is an open map[string]any that lets pipeline stages attach typed // values to a document without creating a new struct. 
Well-known keys are // managed through typed accessor methods — Score, SubIndexes, DenseVector, // SparseVector, DSLInfo, ExtraInfo — so callers never need to reference the // raw key strings. // // Transformer implementations should preserve existing metadata and merge new // keys rather than replacing the map outright, so provenance information // accumulated by earlier stages is not lost. type Document struct { // ID is the unique identifier of the document. ID string `json:"id"` // Content is the content of the document. Content string `json:"content"` // MetaData is the metadata of the document, can be used to store extra information. MetaData map[string]any `json:"meta_data"` } // String returns the content of the document. func (d *Document) String() string { return d.Content } // WithSubIndexes sets the sub-indexes on the document metadata and returns the // document for chaining. Sub-indexes let an Indexer route a document into // multiple logical partitions of a vector store simultaneously. // Use [Document.SubIndexes] to retrieve them. func (d *Document) WithSubIndexes(indexes []string) *Document { if d.MetaData == nil { d.MetaData = make(map[string]any) } d.MetaData[docMetaDataKeySubIndexes] = indexes return d } // SubIndexes returns the sub indexes of the document. // can use doc.WithSubIndexes() to set the sub indexes. func (d *Document) SubIndexes() []string { if d.MetaData == nil { return nil } indexes, ok := d.MetaData[docMetaDataKeySubIndexes].([]string) if ok { return indexes } return nil } // WithScore sets the relevance score on the document, typically written by a // Retriever after ranking results. A higher score means higher relevance. // Note: [retriever.WithScoreThreshold] filters by this value, not sort order. // Use [Document.Score] to retrieve it. 
func (d *Document) WithScore(score float64) *Document { if d.MetaData == nil { d.MetaData = make(map[string]any) } d.MetaData[docMetaDataKeyScore] = score return d } // Score returns the score of the document. // can use doc.WithScore() to set the score. func (d *Document) Score() float64 { if d.MetaData == nil { return 0 } score, ok := d.MetaData[docMetaDataKeyScore].(float64) if ok { return score } return 0 } // WithExtraInfo sets the extra info of the document. // can use doc.ExtraInfo() to get the extra info. func (d *Document) WithExtraInfo(extraInfo string) *Document { if d.MetaData == nil { d.MetaData = make(map[string]any) } d.MetaData[docMetaDataKeyExtraInfo] = extraInfo return d } // ExtraInfo returns the extra info of the document. // can use doc.WithExtraInfo() to set the extra info. func (d *Document) ExtraInfo() string { if d.MetaData == nil { return "" } extraInfo, ok := d.MetaData[docMetaDataKeyExtraInfo].(string) if ok { return extraInfo } return "" } // WithDSLInfo attaches a domain-specific-language query description to the // document. This is consumed by Retriever implementations that support // structured queries (e.g., filter expressions) alongside vector search. // Use [Document.DSLInfo] to retrieve it. func (d *Document) WithDSLInfo(dslInfo map[string]any) *Document { if d.MetaData == nil { d.MetaData = make(map[string]any) } d.MetaData[docMetaDataKeyDSL] = dslInfo return d } // DSLInfo returns the dsl info of the document. // can use doc.WithDSLInfo() to set the dsl info. func (d *Document) DSLInfo() map[string]any { if d.MetaData == nil { return nil } dslInfo, ok := d.MetaData[docMetaDataKeyDSL].(map[string]any) if ok { return dslInfo } return nil } // WithDenseVector sets the dense vector of the document. // can use doc.DenseVector() to get the dense vector. 
func (d *Document) WithDenseVector(vector []float64) *Document { if d.MetaData == nil { d.MetaData = make(map[string]any) } d.MetaData[docMetaDataKeyDenseVector] = vector return d } // DenseVector returns the dense vector of the document. // can use doc.WithDenseVector() to set the dense vector. func (d *Document) DenseVector() []float64 { if d.MetaData == nil { return nil } vector, ok := d.MetaData[docMetaDataKeyDenseVector].([]float64) if ok { return vector } return nil } // WithSparseVector sets the sparse vector of the document, key indices -> value vector. // can use doc.SparseVector() to get the sparse vector. func (d *Document) WithSparseVector(sparse map[int]float64) *Document { if d.MetaData == nil { d.MetaData = make(map[string]any) } d.MetaData[docMetaDataKeySparseVector] = sparse return d } // SparseVector returns the sparse vector of the document, key indices -> value vector. // can use doc.WithSparseVector() to set the sparse vector. func (d *Document) SparseVector() map[int]float64 { if d.MetaData == nil { return nil } sparse, ok := d.MetaData[docMetaDataKeySparseVector].(map[int]float64) if ok { return sparse } return nil } ================================================ FILE: schema/document_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package schema import ( "testing" "github.com/smartystreets/goconvey/convey" ) func TestDocument(t *testing.T) { convey.Convey("test document", t, func() { var ( subIndexes = []string{"hello", "bye"} score = 1.1 extraInfo = "asd" dslInfo = map[string]any{"hello": true} vector = []float64{1.1, 2.2} ) d := &Document{ ID: "asd", Content: "qwe", MetaData: nil, } d.WithSubIndexes(subIndexes). WithDenseVector(vector). WithScore(score). WithExtraInfo(extraInfo). WithDSLInfo(dslInfo) convey.So(d.SubIndexes(), convey.ShouldEqual, subIndexes) convey.So(d.Score(), convey.ShouldEqual, score) convey.So(d.ExtraInfo(), convey.ShouldEqual, extraInfo) convey.So(d.DSLInfo(), convey.ShouldEqual, dslInfo) convey.So(d.DenseVector(), convey.ShouldEqual, vector) }) } ================================================ FILE: schema/message.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package schema import ( "context" "fmt" "io" "reflect" "sort" "strings" "sync" "text/template" "github.com/nikolalohinski/gonja" "github.com/nikolalohinski/gonja/config" "github.com/nikolalohinski/gonja/nodes" "github.com/nikolalohinski/gonja/parser" "github.com/slongfield/pyfmt" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/generic" ) func init() { internal.RegisterStreamChunkConcatFunc(ConcatMessages) internal.RegisterStreamChunkConcatFunc(ConcatMessageArray) internal.RegisterStreamChunkConcatFunc(ConcatToolResults) } // ConcatMessageArray merges aligned slices of messages into a single slice, // concatenating messages at the same index across the input arrays. func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { arrayLen := len(mas[0]) ret := make([]*Message, arrayLen) slicesToConcat := make([][]*Message, arrayLen) for _, ma := range mas { if len(ma) != arrayLen { return nil, fmt.Errorf("unexpected array length. "+ "Got %d, expected %d", len(ma), arrayLen) } for i := 0; i < arrayLen; i++ { m := ma[i] if m != nil { slicesToConcat[i] = append(slicesToConcat[i], m) } } } for i, slice := range slicesToConcat { if len(slice) == 0 { ret[i] = nil } else if len(slice) == 1 { ret[i] = slice[0] } else { cm, err := ConcatMessages(slice) if err != nil { return nil, err } ret[i] = cm } } return ret, nil } // FormatType used by MessageTemplate.Format type FormatType uint8 const ( // FString Supported by pyfmt(github.com/slongfield/pyfmt), which is an implementation of https://peps.python.org/pep-3101/. FString FormatType = 0 // GoTemplate https://pkg.go.dev/text/template. GoTemplate FormatType = 1 // Jinja2 Supported by gonja(github.com/nikolalohinski/gonja), which is a implementation of https://jinja.palletsprojects.com/en/3.1.x/templates/. Jinja2 FormatType = 2 ) // RoleType is the type of the role of a message. type RoleType string const ( // Assistant is the role of an assistant, means the message is returned by ChatModel. 
Assistant RoleType = "assistant" // User is the role of a user, means the message is a user message. User RoleType = "user" // System is the role of a system, means the message is a system message. System RoleType = "system" // Tool is the role of a tool, means the message is a tool call output. Tool RoleType = "tool" ) // FunctionCall is the function call in a message. // It's used in Assistant Message. type FunctionCall struct { // Name is the name of the function to call, it can be used to identify the specific function. Name string `json:"name,omitempty"` // Arguments is the arguments to call the function with, in JSON format. Arguments string `json:"arguments,omitempty"` } // ToolCall is the tool call in a message. // It's used in Assistant Message when there are tool calls should be made. type ToolCall struct { // Index is used when there are multiple tool calls in a message. // In stream mode, it's used to identify the chunk of the tool call for merging. Index *int `json:"index,omitempty"` // ID is the id of the tool call, it can be used to identify the specific tool call. ID string `json:"id"` // Type is the type of the tool call, default is "function". Type string `json:"type"` // Function is the function call to be made. Function FunctionCall `json:"function"` // Extra is used to store extra information for the tool call. Extra map[string]any `json:"extra,omitempty"` } // ImageURLDetail is the detail of the image url. type ImageURLDetail string const ( // ImageURLDetailHigh means the high quality image url. ImageURLDetailHigh ImageURLDetail = "high" // ImageURLDetailLow means the low quality image url. ImageURLDetailLow ImageURLDetail = "low" // ImageURLDetailAuto means the auto quality image url. ImageURLDetailAuto ImageURLDetail = "auto" ) // MessagePartCommon represents the common abstract components for input and output of multi-modal types. type MessagePartCommon struct { // URL is primarily used for HTTP or HTTPS access links. 
// For data in the format 'data:[][;base64],' (the 'data' URL Schema of RFC-2397 (https://www.rfc-editor.org/rfc/rfc2397)), // it is recommended to use Base64Data and MIMEType fields separately instead. URL *string `json:"url,omitempty"` // Base64Data represents the binary data in Base64 encoded string format. Base64Data *string `json:"base64data,omitempty"` // MIMEType is the mime type , eg."image/png",""audio/wav" etc. MIMEType string `json:"mime_type,omitempty"` // Deprecated: Use MessageOutputPart.Extra or MessageInputPart.Extra to set additional metadata instead. Extra map[string]any `json:"extra,omitempty"` } // MessageInputImage is used to represent an image part in message. // Choose either URL or Base64Data. type MessageInputImage struct { MessagePartCommon // Detail is the quality of the image url. Detail ImageURLDetail `json:"detail,omitempty"` } // MessageInputAudio is used to represent an audio part in message. // Choose either URL or Base64Data. type MessageInputAudio struct { MessagePartCommon } // MessageInputVideo is used to represent a video part in message. // Choose either URL or Base64Data. type MessageInputVideo struct { MessagePartCommon } // MessageInputFile is used to represent a file part in message. // Choose either URL or Base64Data. type MessageInputFile struct { MessagePartCommon // Name represents the filename. // Optional. Name string `json:"name,omitempty"` } // MessageInputPart represents the input part of message. type MessageInputPart struct { Type ChatMessagePartType `json:"type"` Text string `json:"text,omitempty"` // Image is the image input of the part, it's used when Type is "image_url". Image *MessageInputImage `json:"image,omitempty"` // Audio is the audio input of the part, it's used when Type is "audio_url". Audio *MessageInputAudio `json:"audio,omitempty"` // Video is the video input of the part, it's used when Type is "video_url". 
Video *MessageInputVideo `json:"video,omitempty"` // File is the file input of the part, it's used when Type is "file_url". File *MessageInputFile `json:"file,omitempty"` // Extra is used to store extra information. Extra map[string]any `json:"extra,omitempty"` } // MessageOutputImage is used to represent an image part in message. type MessageOutputImage struct { MessagePartCommon } // MessageOutputAudio is used to represent an audio part in message. type MessageOutputAudio struct { MessagePartCommon } // MessageOutputVideo is used to represent a video part in message. type MessageOutputVideo struct { MessagePartCommon } // MessageOutputReasoning represents the reasoning content generated by reasoning models. // Some models produce reasoning steps before generating the final response. // This struct captures that reasoning output. type MessageOutputReasoning struct { // Text is either the thought summary or the raw reasoning text itself. Text string `json:"text,omitempty"` // Signature contains encrypted reasoning tokens. // Required by some models when passing reasoning context back in subsequent requests. Signature string `json:"signature,omitempty"` } // MessageStreamingMeta contains metadata for streaming responses. // It is used to track position of part when the model outputs multiple parts in a single response. type MessageStreamingMeta struct { // Index specifies the index position of this part in the final response. // This is useful for reassembling multiple reasoning/content parts in correct order. Index int `json:"index,omitempty"` } // MessageOutputPart represents a part of an assistant-generated message. // It can contain text, or multimedia content like images, audio, or video. type MessageOutputPart struct { // Type is the type of the part, e.g. "text", "image_url", "audio_url", "video_url". Type ChatMessagePartType `json:"type"` // Text is the text of the part, it's used when Type is "text". 
Text string `json:"text,omitempty"` // Image is the image output of the part, used when Type is ChatMessagePartTypeImageURL. Image *MessageOutputImage `json:"image,omitempty"` // Audio is the audio output of the part, used when Type is ChatMessagePartTypeAudioURL. Audio *MessageOutputAudio `json:"audio,omitempty"` // Video is the video output of the part, used when Type is ChatMessagePartTypeVideoURL. Video *MessageOutputVideo `json:"video,omitempty"` // Reasoning contains the reasoning content generated by the model. // Used when Type is ChatMessagePartTypeReasoning. Reasoning *MessageOutputReasoning `json:"reasoning,omitempty"` // Extra is used to store extra information. Extra map[string]any `json:"extra,omitempty"` // StreamingMeta contains metadata for streaming responses. // This field is typically used at runtime and not serialized. StreamingMeta *MessageStreamingMeta `json:"-"` } // ToolPartType defines the type of content in a tool output part. // It is used to distinguish between different types of multimodal content returned by tools. type ToolPartType string const ( // ToolPartTypeText means the part is a text. ToolPartTypeText ToolPartType = "text" // ToolPartTypeImage means the part is an image url. ToolPartTypeImage ToolPartType = "image" // ToolPartTypeAudio means the part is an audio url. ToolPartTypeAudio ToolPartType = "audio" // ToolPartTypeVideo means the part is a video url. ToolPartTypeVideo ToolPartType = "video" // ToolPartTypeFile means the part is a file url. ToolPartTypeFile ToolPartType = "file" ) // ToolOutputImage represents an image in tool output. // It contains URL or Base64-encoded data along with MIME type information. type ToolOutputImage struct { MessagePartCommon } // ToolOutputAudio represents an audio file in tool output. // It contains URL or Base64-encoded data along with MIME type information. type ToolOutputAudio struct { MessagePartCommon } // ToolOutputVideo represents a video file in tool output. 
// It contains URL or Base64-encoded data along with MIME type information.
type ToolOutputVideo struct {
	MessagePartCommon
}

// ToolOutputFile represents a generic file in tool output.
// It contains URL or Base64-encoded data along with MIME type information.
type ToolOutputFile struct {
	MessagePartCommon
}

// ToolOutputPart represents a single part of a tool's multimodal output.
// Type selects which content field is meaningful: Text for ToolPartTypeText,
// Image for ToolPartTypeImage, Audio for ToolPartTypeAudio, Video for
// ToolPartTypeVideo, and File for ToolPartTypeFile.
type ToolOutputPart struct {
	// Type is the type of the part: "text", "image", "audio", "video" or "file"
	// (see the ToolPartType constants).
	Type ToolPartType `json:"type"`

	// Text is the text content, used when Type is "text".
	Text string `json:"text,omitempty"`

	// Image is the image content, used when Type is ToolPartTypeImage.
	Image *ToolOutputImage `json:"image,omitempty"`

	// Audio is the audio content, used when Type is ToolPartTypeAudio.
	Audio *ToolOutputAudio `json:"audio,omitempty"`

	// Video is the video content, used when Type is ToolPartTypeVideo.
	Video *ToolOutputVideo `json:"video,omitempty"`

	// File is the file content, used when Type is ToolPartTypeFile.
	File *ToolOutputFile `json:"file,omitempty"`

	// Extra is used to store extra information.
	Extra map[string]any `json:"extra,omitempty"`
}

// ToolArgument contains the input information for a tool call.
// It is used to pass tool call arguments to enhanced tools.
type ToolArgument struct {
	// Text contains the arguments for the tool call in JSON format.
	Text string `json:"text,omitempty"`
}

// ToolResult represents the structured multimodal output from a tool execution.
// It is used when a tool needs to return more than just a simple string,
// such as images, files, or other structured data.
type ToolResult struct {
	// Parts contains the multimodal output parts. Each part can be a different
	// type of content, like text, an image, or a file.
	Parts []ToolOutputPart `json:"parts,omitempty"`
}

// convToolOutputPartToMessageInputPart converts one tool output part into the
// equivalent model input part. The ToolPartType is mapped to the matching
// ChatMessagePartType (e.g. ToolPartTypeImage -> ChatMessagePartTypeImageURL),
// and the media payload (MessagePartCommon) and Extra are carried over as-is.
// It returns an error when the content field selected by Type is nil, or when
// Type is not one of the known ToolPartType values.
func convToolOutputPartToMessageInputPart(toolPart ToolOutputPart) (MessageInputPart, error) {
	switch toolPart.Type {
	case ToolPartTypeText:
		return MessageInputPart{
			Type:  ChatMessagePartTypeText,
			Text:  toolPart.Text,
			Extra: toolPart.Extra,
		}, nil
	case ToolPartTypeImage:
		if toolPart.Image == nil {
			return MessageInputPart{}, fmt.Errorf("image content is nil for tool part type %v", toolPart.Type)
		}
		return MessageInputPart{
			Type:  ChatMessagePartTypeImageURL,
			Image: &MessageInputImage{MessagePartCommon: toolPart.Image.MessagePartCommon},
			Extra: toolPart.Extra,
		}, nil
	case ToolPartTypeAudio:
		if toolPart.Audio == nil {
			return MessageInputPart{}, fmt.Errorf("audio content is nil for tool part type %v", toolPart.Type)
		}
		return MessageInputPart{
			Type:  ChatMessagePartTypeAudioURL,
			Audio: &MessageInputAudio{MessagePartCommon: toolPart.Audio.MessagePartCommon},
			Extra: toolPart.Extra,
		}, nil
	case ToolPartTypeVideo:
		if toolPart.Video == nil {
			return MessageInputPart{}, fmt.Errorf("video content is nil for tool part type %v", toolPart.Type)
		}
		return MessageInputPart{
			Type:  ChatMessagePartTypeVideoURL,
			Video: &MessageInputVideo{MessagePartCommon: toolPart.Video.MessagePartCommon},
			Extra: toolPart.Extra,
		}, nil
	case ToolPartTypeFile:
		if toolPart.File == nil {
			return MessageInputPart{}, fmt.Errorf("file content is nil for tool part type %v", toolPart.Type)
		}
		return MessageInputPart{
			Type:  ChatMessagePartTypeFileURL,
			File:  &MessageInputFile{MessagePartCommon: toolPart.File.MessagePartCommon},
			Extra: toolPart.Extra,
		}, nil
	default:
		return MessageInputPart{}, fmt.Errorf("unknown tool part type: %v", toolPart.Type)
	}
}

// ToMessageInputParts converts the result's ToolOutputPart slice into a
// MessageInputPart slice. This is used when passing tool results as input to
// the model. A nil receiver or a result without parts yields (nil, nil).
//
// Returns:
//   - []MessageInputPart: The converted message input parts that can be used in a Message.
// - error: An error if conversion fails due to unknown part types or nil content fields. // // Example: // // toolResult := &schema.ToolResult{ // Parts: []schema.ToolOutputPart{ // {Type: schema.ToolPartTypeText, Text: "Result text"}, // {Type: schema.ToolPartTypeImage, Image: &schema.ToolOutputImage{...}}, // }, // } // inputParts, err := toolResult.ToMessageInputParts() func (tr *ToolResult) ToMessageInputParts() ([]MessageInputPart, error) { if tr == nil || len(tr.Parts) == 0 { return nil, nil } result := make([]MessageInputPart, len(tr.Parts)) for i, part := range tr.Parts { var err error result[i], err = convToolOutputPartToMessageInputPart(part) if err != nil { return nil, err } } return result, nil } // Deprecated: This struct is deprecated as the MultiContent field is deprecated. // For the image input part of the model, use MessageInputImage. // For the image output part of the model, use MessageOutputImage. // Choose either URL or URI. // If your model implementation supports it, URL could embed inline image data // as defined in RFC-2397. type ChatMessageImageURL struct { // URL can either be a traditional URL or a special URL conforming to RFC-2397 (https://www.rfc-editor.org/rfc/rfc2397). // double check with model implementations for detailed instructions on how to use this. URL string `json:"url,omitempty"` URI string `json:"uri,omitempty"` // Detail is the quality of the image url. Detail ImageURLDetail `json:"detail,omitempty"` // MIMEType is the mime type of the image, eg. "image/png". MIMEType string `json:"mime_type,omitempty"` // Extra is used to store extra information for the image url. Extra map[string]any `json:"extra,omitempty"` } // ChatMessagePartType is the type of the part in a chat message. type ChatMessagePartType string const ( // ChatMessagePartTypeText means the part is a text. ChatMessagePartTypeText ChatMessagePartType = "text" // ChatMessagePartTypeImageURL means the part is an image url. 
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" // ChatMessagePartTypeAudioURL means the part is an audio url. ChatMessagePartTypeAudioURL ChatMessagePartType = "audio_url" // ChatMessagePartTypeVideoURL means the part is a video url. ChatMessagePartTypeVideoURL ChatMessagePartType = "video_url" // ChatMessagePartTypeFileURL means the part is a file url. ChatMessagePartTypeFileURL ChatMessagePartType = "file_url" // ChatMessagePartTypeReasoning means the part is a reasoning block. ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning" ) // Deprecated: This struct is deprecated as the MultiContent field is deprecated. // For the audio input part of the model, use MessageInputAudio. // For the audio output part of the model, use MessageOutputAudio. // Choose either URL or URI. // If supported, URL may embed inline audio data per RFC-2397. type ChatMessageAudioURL struct { // URL can either be a traditional URL or a special URL conforming to RFC-2397 (https://www.rfc-editor.org/rfc/rfc2397). // double check with model implementations for detailed instructions on how to use this. URL string `json:"url,omitempty"` URI string `json:"uri,omitempty"` // MIMEType is the mime type of the audio, eg. "audio/wav" or "audio/ogg". MIMEType string `json:"mime_type,omitempty"` // Extra is used to store extra information for the audio url. Extra map[string]any `json:"extra,omitempty"` } // Deprecated: This struct is deprecated as the MultiContent field is deprecated. // For the video input part of the model, use MessageInputVideo. // For the video output part of the model, use MessageOutputVideo. // Choose either URL or URI. // If supported, URL may embed inline video data per RFC-2397. type ChatMessageVideoURL struct { // URL can either be a traditional URL or a special URL conforming to RFC-2397 (https://www.rfc-editor.org/rfc/rfc2397). // double check with model implementations for detailed instructions on how to use this. 
URL string `json:"url,omitempty"` URI string `json:"uri,omitempty"` // MIMEType is the mime type of the video, eg. "video/mp4". MIMEType string `json:"mime_type,omitempty"` // Extra is used to store extra information for the video url. Extra map[string]any `json:"extra,omitempty"` } // Deprecated: This struct is deprecated as the MultiContent field is deprecated. // For the file input part of the model, use MessageInputFile. // Choose either URL or URI. type ChatMessageFileURL struct { URL string `json:"url,omitempty"` URI string `json:"uri,omitempty"` // MIMEType is the mime type of the file, eg. "application/pdf", "text/plain". MIMEType string `json:"mime_type,omitempty"` // Name is the name of the file. Name string `json:"name,omitempty"` // Extra is used to store extra information for the file url. Extra map[string]any `json:"extra,omitempty"` } // Deprecated: This struct is deprecated as the MultiContent field is deprecated. // For model input, use MessageInputPart. For model output, use MessageOutputPart. type ChatMessagePart struct { // Type is the type of the part, eg. "text", "image_url", "audio_url", "video_url", "file_url". Type ChatMessagePartType `json:"type,omitempty"` // Text is the text of the part, it's used when Type is "text". Text string `json:"text,omitempty"` // ImageURL is the image url of the part, it's used when Type is "image_url". ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` // AudioURL is the audio url of the part, it's used when Type is "audio_url". AudioURL *ChatMessageAudioURL `json:"audio_url,omitempty"` // VideoURL is the video url of the part, it's used when Type is "video_url". VideoURL *ChatMessageVideoURL `json:"video_url,omitempty"` // FileURL is the file url of the part, it's used when Type is "file_url". FileURL *ChatMessageFileURL `json:"file_url,omitempty"` } // LogProbs is the top-level structure containing the log probability information. 
type LogProbs struct {
	// Content is a list of message content tokens with log probability information.
	Content []LogProb `json:"content"`
}

// LogProb represents the probability information for a token.
type LogProb struct {
	// Token represents the text of the token, which is a contiguous sequence of characters
	// (e.g., a word, part of a word, or punctuation) as understood by the tokenization process used by the language model.
	Token string `json:"token"`
	// LogProb is the log probability of this token, if it is within the top 20 most likely tokens.
	// Otherwise, the value `-9999.0` is used to signify that the token is very unlikely.
	LogProb float64 `json:"logprob"`
	// Bytes is a list of integers representing the UTF-8 bytes representation of the token.
	// Useful in instances where characters are represented by multiple tokens and
	// their byte representations must be combined to generate the correct text
	// representation. Can be `null` if there is no bytes representation for the token.
	Bytes []int64 `json:"bytes,omitempty"` // omitempty: dropped from marshaled JSON when nil or empty
	// TopLogProbs is a list of the most likely tokens and their log probability, at this token position.
	// In rare cases, there may be fewer than the number of requested top_logprobs returned.
	// Note: no omitempty, so a nil slice is marshaled as JSON null.
	TopLogProbs []TopLogProb `json:"top_logprobs"`
}

// TopLogProb describes a likely token and its log probability at a position.
type TopLogProb struct {
	// Token represents the text of the token, which is a contiguous sequence of characters
	// (e.g., a word, part of a word, or punctuation) as understood by the tokenization process used by the language model.
	Token string `json:"token"`
	// LogProb is the log probability of this token, if it is within the top 20 most likely tokens.
	// Otherwise, the value `-9999.0` is used to signify that the token is very unlikely.
	LogProb float64 `json:"logprob"`
	// Bytes is a list of integers representing the UTF-8 bytes representation of the token.
	// Useful in instances where characters are represented by multiple tokens and
	// their byte representations must be combined to generate the correct text
	// representation. Can be `null` if there is no bytes representation for the token.
	Bytes []int64 `json:"bytes,omitempty"`
}

// ResponseMeta collects meta information about a chat response.
type ResponseMeta struct {
	// FinishReason is the reason why the chat response is finished.
	// It's usually "stop", "length", "tool_calls", "content_filter", "null". This is defined by chat model implementation.
	FinishReason string `json:"finish_reason,omitempty"`
	// Usage is the token usage of the chat response, whether usage exists depends on whether the chat model implementation returns.
	Usage *TokenUsage `json:"usage,omitempty"`
	// LogProbs is Log probability information.
	// It is nil when the chat model implementation does not report log probabilities.
	LogProbs *LogProbs `json:"logprobs,omitempty"`
}

// Message denotes the data structure for model input and output, originating from either user input or model return.
// It supports both text-only and multimodal content.
//
// For text-only input from a user, use the Content field:
//
//	&schema.Message{
//		Role:    schema.User,
//		Content: "What is the capital of France?",
//	}
//
// For multimodal input from a user, use the UserInputMultiContent field.
// This allows combining text with other media like images:
//
//	&schema.Message{
//		Role: schema.User,
//		UserInputMultiContent: []schema.MessageInputPart{
//			{Type: schema.ChatMessagePartTypeText, Text: "What is in this image?"},
//			{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{
//				MessagePartCommon: schema.MessagePartCommon{
//					URL: toPtr("https://example.com/cat.jpg"),
//				},
//				Detail: schema.ImageURLDetailHigh,
//			}},
//		},
//	}
//
// When the model returns multimodal content, it is available in the AssistantGenMultiContent field:
//
//	&schema.Message{
//		Role: schema.Assistant,
//		AssistantGenMultiContent: []schema.MessageOutputPart{
//			{Type: schema.ChatMessagePartTypeText, Text: "Here is the generated image:"},
//			{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{
//				MessagePartCommon: schema.MessagePartCommon{
//					Base64Data: toPtr("base64_image_binary"),
//					MIMEType:   "image/png",
//				},
//			}},
//		},
//	}
//
// In the examples above, toPtr stands for any helper that returns a pointer to its argument.
type Message struct {
	// Role identifies the author of the message, e.g. System, User, Assistant or Tool.
	Role RoleType `json:"role"`

	// Content is for user text input and model text output.
	Content string `json:"content"`

	// if MultiContent is not empty, use this instead of Content
	// if MultiContent is empty, use Content
	// Deprecated: Use UserInputMultiContent for user multimodal inputs and AssistantGenMultiContent for model multimodal outputs.
	MultiContent []ChatMessagePart `json:"multi_content,omitempty"`

	// UserInputMultiContent passes multimodal content provided by the user to the model.
	UserInputMultiContent []MessageInputPart `json:"user_input_multi_content,omitempty"`

	// AssistantGenMultiContent is for receiving multimodal output from the model.
	// Note that its JSON key is "assistant_output_multi_content".
	AssistantGenMultiContent []MessageOutputPart `json:"assistant_output_multi_content,omitempty"`

	Name string `json:"name,omitempty"`

	// only for AssistantMessage: the tool calls requested by the model.
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`

	// only for ToolMessage: the ID of the tool call this message answers.
	ToolCallID string `json:"tool_call_id,omitempty"`
	// only for ToolMessage: the name of the tool that produced this message.
	ToolName string `json:"tool_name,omitempty"`

	// ResponseMeta carries finish reason, token usage and log probabilities of a model response.
	ResponseMeta *ResponseMeta `json:"response_meta,omitempty"`

	// ReasoningContent is the thinking process of the model, which will be included when the model returns reasoning content.
	ReasoningContent string `json:"reasoning_content,omitempty"`

	// customized information for model implementation
	Extra map[string]any `json:"extra,omitempty"`
}

// TokenUsage represents the token usage of chat model request.
type TokenUsage struct {
	// PromptTokens is the number of prompt tokens, including all the input tokens of this request.
	PromptTokens int `json:"prompt_tokens"`
	// PromptTokenDetails is a breakdown of the prompt tokens.
	PromptTokenDetails PromptTokenDetails `json:"prompt_token_details"`
	// CompletionTokens is the number of completion tokens.
	CompletionTokens int `json:"completion_tokens"`
	// TotalTokens is the total number of tokens.
	TotalTokens int `json:"total_tokens"`
	// CompletionTokensDetails is breakdown of completion tokens.
	// Note that its JSON key is "completion_token_details" (singular "token").
	CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"`
}

// CompletionTokensDetails provides a breakdown of completion token usage.
type CompletionTokensDetails struct {
	// ReasoningTokens tokens generated by the model for reasoning.
	// This is currently supported by OpenAI, Gemini, ARK and Qwen chat models.
	// For other models, this field will be 0.
	ReasoningTokens int `json:"reasoning_tokens,omitempty"`
}

// PromptTokenDetails provides a breakdown of prompt token usage.
type PromptTokenDetails struct {
	// Cached tokens present in the prompt.
	CachedTokens int `json:"cached_tokens"`
}

// Compile-time checks that *Message and the placeholder returned by
// MessagesPlaceholder both implement MessagesTemplate.
var _ MessagesTemplate = &Message{}
var _ MessagesTemplate = MessagesPlaceholder("", false)

// MessagesTemplate is the interface for messages template.
// It's used to render a template to a list of messages. // e.g. // // chatTemplate := prompt.FromMessages( // schema.SystemMessage("you are eino helper"), // schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params // ) // msgs, err := chatTemplate.Format(ctx, params) type MessagesTemplate interface { Format(ctx context.Context, vs map[string]any, formatType FormatType) ([]*Message, error) } type messagesPlaceholder struct { key string optional bool } // MessagesPlaceholder can render a placeholder to a list of messages in params. // e.g. // // placeholder := MessagesPlaceholder("history", false) // params := map[string]any{ // "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great freamwork to build llm apps"}}, // "query": "how to use eino?", // } // chatTemplate := chatTpl := prompt.FromMessages( // schema.SystemMessage("you are eino helper"), // schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params // ) // msgs, err := chatTemplate.Format(ctx, params) func MessagesPlaceholder(key string, optional bool) MessagesTemplate { return &messagesPlaceholder{ key: key, optional: optional, } } // Format just return the messages of specified key. // because it's a placeholder. // e.g. 
// // placeholder := MessagesPlaceholder("history", false) // params := map[string]any{ // "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great freamwork to build llm apps"}}, // "query": "how to use eino?", // } // msgs, err := placeholder.Format(ctx, params) // <= this will return the value of "history" in params func (p *messagesPlaceholder) Format(_ context.Context, vs map[string]any, _ FormatType) ([]*Message, error) { v, ok := vs[p.key] if !ok { if p.optional { return []*Message{}, nil } return nil, fmt.Errorf("message placeholder format: %s not found", p.key) } msgs, ok := v.([]*Message) if !ok { return nil, fmt.Errorf("only messages can be used to format message placeholder, key: %v, actual type: %v", p.key, reflect.TypeOf(v)) } return msgs, nil } func formatContent(content string, vs map[string]any, formatType FormatType) (string, error) { switch formatType { case FString: return pyfmt.Fmt(content, vs) case GoTemplate: parsedTmpl, err := template.New("template"). Option("missingkey=error"). Parse(content) if err != nil { return "", err } sb := new(strings.Builder) err = parsedTmpl.Execute(sb, vs) if err != nil { return "", err } return sb.String(), nil case Jinja2: env, err := getJinjaEnv() if err != nil { return "", err } tpl, err := env.FromString(content) if err != nil { return "", err } out, err := tpl.Execute(vs) if err != nil { return "", err } return out, nil default: return "", fmt.Errorf("unknown format type: %v", formatType) } } // Format returns the messages after rendering by the given formatType. // e.g. 
// // msg := schema.UserMessage("hello world, {name}") // msgs, err := msg.Format(ctx, map[string]any{"name": "eino"}, schema.FString) // <= this will render the content of msg by pyfmt // // msgs[0].Content will be "hello world, eino" func (m *Message) Format(_ context.Context, vs map[string]any, formatType FormatType) ([]*Message, error) { c, err := formatContent(m.Content, vs, formatType) if err != nil { return nil, err } copied := *m copied.Content = c if len(m.MultiContent) > 0 { copied.MultiContent, err = formatMultiContent(m.MultiContent, vs, formatType) if err != nil { return nil, err } } if len(m.UserInputMultiContent) > 0 { copied.UserInputMultiContent, err = formatUserInputMultiContent(m.UserInputMultiContent, vs, formatType) if err != nil { return nil, err } } return []*Message{&copied}, nil } func formatMultiContent(multiContent []ChatMessagePart, vs map[string]any, formatType FormatType) ([]ChatMessagePart, error) { copiedMC := make([]ChatMessagePart, len(multiContent)) copy(copiedMC, multiContent) for i, mc := range copiedMC { switch mc.Type { case ChatMessagePartTypeText: nmc, err := formatContent(mc.Text, vs, formatType) if err != nil { return nil, err } copiedMC[i].Text = nmc case ChatMessagePartTypeImageURL: if mc.ImageURL == nil { continue } url, err := formatContent(mc.ImageURL.URL, vs, formatType) if err != nil { return nil, err } copiedMC[i].ImageURL.URL = url case ChatMessagePartTypeAudioURL: if mc.AudioURL == nil { continue } url, err := formatContent(mc.AudioURL.URL, vs, formatType) if err != nil { return nil, err } copiedMC[i].AudioURL.URL = url case ChatMessagePartTypeVideoURL: if mc.VideoURL == nil { continue } url, err := formatContent(mc.VideoURL.URL, vs, formatType) if err != nil { return nil, err } copiedMC[i].VideoURL.URL = url case ChatMessagePartTypeFileURL: if mc.FileURL == nil { continue } url, err := formatContent(mc.FileURL.URL, vs, formatType) if err != nil { return nil, err } copiedMC[i].FileURL.URL = url } } return 
copiedMC, nil } func formatUserInputMultiContent(userInputMultiContent []MessageInputPart, vs map[string]any, formatType FormatType) ([]MessageInputPart, error) { copiedUIMC := make([]MessageInputPart, len(userInputMultiContent)) copy(copiedUIMC, userInputMultiContent) for i, uimc := range copiedUIMC { switch uimc.Type { case ChatMessagePartTypeText: text, err := formatContent(uimc.Text, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].Text = text case ChatMessagePartTypeImageURL: if uimc.Image == nil { continue } if uimc.Image.URL != nil && *uimc.Image.URL != "" { url, err := formatContent(*uimc.Image.URL, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].Image.URL = &url } if uimc.Image.Base64Data != nil && *uimc.Image.Base64Data != "" { base64data, err := formatContent(*uimc.Image.Base64Data, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].Image.Base64Data = &base64data } case ChatMessagePartTypeAudioURL: if uimc.Audio == nil { continue } if uimc.Audio.URL != nil && *uimc.Audio.URL != "" { url, err := formatContent(*uimc.Audio.URL, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].Audio.URL = &url } if uimc.Audio.Base64Data != nil && *uimc.Audio.Base64Data != "" { base64data, err := formatContent(*uimc.Audio.Base64Data, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].Audio.Base64Data = &base64data } case ChatMessagePartTypeVideoURL: if uimc.Video == nil { continue } if uimc.Video.URL != nil && *uimc.Video.URL != "" { url, err := formatContent(*uimc.Video.URL, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].Video.URL = &url } if uimc.Video.Base64Data != nil && *uimc.Video.Base64Data != "" { base64data, err := formatContent(*uimc.Video.Base64Data, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].Video.Base64Data = &base64data } case ChatMessagePartTypeFileURL: if uimc.File == nil { continue } if uimc.File.URL != nil && *uimc.File.URL != "" { url, err := 
formatContent(*uimc.File.URL, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].File.URL = &url } if uimc.File.Base64Data != nil && *uimc.File.Base64Data != "" { base64data, err := formatContent(*uimc.File.Base64Data, vs, formatType) if err != nil { return nil, err } copiedUIMC[i].File.Base64Data = &base64data } } } return copiedUIMC, nil } // String returns the string representation of the message. // e.g. // // msg := schema.UserMessage("hello world") // fmt.Println(msg.String()) // Output will be: `user: hello world`` // // msg := schema.Message{ // Role: schema.Tool, // Content: "{...}", // ToolCallID: "callxxxx" // } // fmt.Println(msg.String()) // Output will be: // tool: {...} // call_id: callxxxx func (m *Message) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("%s: %s", m.Role, m.Content)) if len(m.UserInputMultiContent) > 0 { sb.WriteString("\nuser_input_multi_content:") for i, part := range m.UserInputMultiContent { sb.WriteString(fmt.Sprintf("\n [%d] %s", i, formatInputPart(part))) } } if len(m.AssistantGenMultiContent) > 0 { sb.WriteString("\nassistant_gen_multi_content:") for i, part := range m.AssistantGenMultiContent { sb.WriteString(fmt.Sprintf("\n [%d] %s", i, formatOutputPart(part))) } } if len(m.MultiContent) > 0 { sb.WriteString("\nmulti_content:") for i, part := range m.MultiContent { sb.WriteString(fmt.Sprintf("\n [%d] %s", i, formatChatMessagePart(part))) } } if len(m.ReasoningContent) > 0 { sb.WriteString("\nreasoning content:\n") sb.WriteString(m.ReasoningContent) } if len(m.ToolCalls) > 0 { sb.WriteString("\ntool_calls:\n") for _, tc := range m.ToolCalls { if tc.Index != nil { sb.WriteString(fmt.Sprintf("index[%d]:", *tc.Index)) } sb.WriteString(fmt.Sprintf("%+v\n", tc)) } } if m.ToolCallID != "" { sb.WriteString(fmt.Sprintf("\ntool_call_id: %s", m.ToolCallID)) } if m.ToolName != "" { sb.WriteString(fmt.Sprintf("\ntool_call_name: %s", m.ToolName)) } if m.ResponseMeta != nil { 
sb.WriteString(fmt.Sprintf("\nfinish_reason: %s", m.ResponseMeta.FinishReason)) if m.ResponseMeta.Usage != nil { sb.WriteString(fmt.Sprintf("\nusage: %v", m.ResponseMeta.Usage)) } } return sb.String() } func formatInputPart(part MessageInputPart) string { switch part.Type { case ChatMessagePartTypeText: return fmt.Sprintf("text: %s", part.Text) case ChatMessagePartTypeImageURL: return fmt.Sprintf("image: %s", formatMessageInputMedia(part.Image)) case ChatMessagePartTypeAudioURL: return fmt.Sprintf("audio: %s", formatMessageInputMedia(part.Audio)) case ChatMessagePartTypeVideoURL: return fmt.Sprintf("video: %s", formatMessageInputMedia(part.Video)) case ChatMessagePartTypeFileURL: return fmt.Sprintf("file: %s", formatMessageInputFile(part.File)) default: return fmt.Sprintf("unknown type: %s", part.Type) } } func formatMessageInputMedia[T MessageInputImage | MessageInputAudio | MessageInputVideo](media *T) string { if media == nil { return "" } var parts []string switch v := any(media).(type) { case *MessageInputImage: if v.URL != nil { parts = append(parts, fmt.Sprintf("url=%s", *v.URL)) } if v.Base64Data != nil { parts = append(parts, fmt.Sprintf("base64[%d bytes]", len(*v.Base64Data))) } if v.MIMEType != "" { parts = append(parts, fmt.Sprintf("mime=%s", v.MIMEType)) } if v.Detail != "" { parts = append(parts, fmt.Sprintf("detail=%s", v.Detail)) } if len(v.Extra) > 0 { parts = append(parts, fmt.Sprintf("extra=%v", v.Extra)) } case *MessageInputAudio: if v.URL != nil { parts = append(parts, fmt.Sprintf("url=%s", *v.URL)) } if v.Base64Data != nil { parts = append(parts, fmt.Sprintf("base64[%d bytes]", len(*v.Base64Data))) } if v.MIMEType != "" { parts = append(parts, fmt.Sprintf("mime=%s", v.MIMEType)) } if len(v.Extra) > 0 { parts = append(parts, fmt.Sprintf("extra=%v", v.Extra)) } case *MessageInputVideo: if v.URL != nil { parts = append(parts, fmt.Sprintf("url=%s", *v.URL)) } if v.Base64Data != nil { parts = append(parts, fmt.Sprintf("base64[%d bytes]", 
len(*v.Base64Data))) } if v.MIMEType != "" { parts = append(parts, fmt.Sprintf("mime=%s", v.MIMEType)) } if len(v.Extra) > 0 { parts = append(parts, fmt.Sprintf("extra=%v", v.Extra)) } } if len(parts) == 0 { return "" } return strings.Join(parts, ", ") } func formatMessageInputFile(file *MessageInputFile) string { if file == nil { return "" } var parts []string if file.URL != nil { parts = append(parts, fmt.Sprintf("url=%s", *file.URL)) } if file.Base64Data != nil { parts = append(parts, fmt.Sprintf("base64[%d bytes]", len(*file.Base64Data))) } if file.MIMEType != "" { parts = append(parts, fmt.Sprintf("mime=%s", file.MIMEType)) } if file.Name != "" { parts = append(parts, fmt.Sprintf("name=%s", file.Name)) } if len(file.Extra) > 0 { parts = append(parts, fmt.Sprintf("extra=%v", file.Extra)) } if len(parts) == 0 { return "" } return strings.Join(parts, ", ") } func formatOutputPart(part MessageOutputPart) string { switch part.Type { case ChatMessagePartTypeText: return fmt.Sprintf("text: %s", part.Text) case ChatMessagePartTypeImageURL: return fmt.Sprintf("image: %s", formatMessageOutputMedia(part.Image)) case ChatMessagePartTypeAudioURL: return fmt.Sprintf("audio: %s", formatMessageOutputMedia(part.Audio)) case ChatMessagePartTypeVideoURL: return fmt.Sprintf("video: %s", formatMessageOutputMedia(part.Video)) default: return fmt.Sprintf("unknown type: %s", part.Type) } } func formatMessageOutputMedia[T MessageOutputImage | MessageOutputAudio | MessageOutputVideo](media *T) string { if media == nil { return "" } var parts []string switch v := any(media).(type) { case *MessageOutputImage: if v.URL != nil { parts = append(parts, fmt.Sprintf("url=%s", *v.URL)) } if v.Base64Data != nil { parts = append(parts, fmt.Sprintf("base64[%d bytes]", len(*v.Base64Data))) } if v.MIMEType != "" { parts = append(parts, fmt.Sprintf("mime=%s", v.MIMEType)) } if len(v.Extra) > 0 { parts = append(parts, fmt.Sprintf("extra=%v", v.Extra)) } case *MessageOutputAudio: if v.URL != nil { 
parts = append(parts, fmt.Sprintf("url=%s", *v.URL)) } if v.Base64Data != nil { parts = append(parts, fmt.Sprintf("base64[%d bytes]", len(*v.Base64Data))) } if v.MIMEType != "" { parts = append(parts, fmt.Sprintf("mime=%s", v.MIMEType)) } if len(v.Extra) > 0 { parts = append(parts, fmt.Sprintf("extra=%v", v.Extra)) } case *MessageOutputVideo: if v.URL != nil { parts = append(parts, fmt.Sprintf("url=%s", *v.URL)) } if v.Base64Data != nil { parts = append(parts, fmt.Sprintf("base64[%d bytes]", len(*v.Base64Data))) } if v.MIMEType != "" { parts = append(parts, fmt.Sprintf("mime=%s", v.MIMEType)) } if len(v.Extra) > 0 { parts = append(parts, fmt.Sprintf("extra=%v", v.Extra)) } } if len(parts) == 0 { return "" } return strings.Join(parts, ", ") } func formatChatMessagePart(part ChatMessagePart) string { switch part.Type { case ChatMessagePartTypeText: return fmt.Sprintf("text: %s", part.Text) case ChatMessagePartTypeImageURL: if part.ImageURL != nil { return fmt.Sprintf("image_url: %s", part.ImageURL.URL) } return "image_url: " case ChatMessagePartTypeAudioURL: if part.AudioURL != nil { return fmt.Sprintf("audio_url: %s", part.AudioURL.URL) } return "audio_url: " case ChatMessagePartTypeVideoURL: if part.VideoURL != nil { return fmt.Sprintf("video_url: %s", part.VideoURL.URL) } return "video_url: " case ChatMessagePartTypeFileURL: if part.FileURL != nil { return fmt.Sprintf("file_url: %s", part.FileURL.URL) } return "file_url: " default: return fmt.Sprintf("unknown type: %s", part.Type) } } // SystemMessage represents a message with Role "system". func SystemMessage(content string) *Message { return &Message{ Role: System, Content: content, } } // AssistantMessage represents a message with Role "assistant". func AssistantMessage(content string, toolCalls []ToolCall) *Message { return &Message{ Role: Assistant, Content: content, ToolCalls: toolCalls, } } // UserMessage represents a message with Role "user". 
func UserMessage(content string) *Message { return &Message{ Role: User, Content: content, } } type toolMessageOptions struct { toolName string } // ToolMessageOption defines a option for ToolMessage type ToolMessageOption func(*toolMessageOptions) // WithToolName returns a ToolMessageOption that sets the tool call name. func WithToolName(name string) ToolMessageOption { return func(o *toolMessageOptions) { o.toolName = name } } // ToolMessage represents a message with Role "tool". func ToolMessage(content string, toolCallID string, opts ...ToolMessageOption) *Message { o := &toolMessageOptions{} for _, opt := range opts { opt(o) } return &Message{ Role: Tool, Content: content, ToolCallID: toolCallID, ToolName: o.toolName, } } // ConcatToolResults merges multiple ToolResult chunks into a single ToolResult. // It collects all ToolOutputParts from the input chunks and merges contiguous text parts within each chunk. // // Merge rules: // - Text parts: Contiguous text parts within each chunk are concatenated into a single text part. // - Non-text parts (image, audio, video, file): These parts are kept as-is without merging. // Each non-text part type can only appear in one chunk; if the same non-text type appears // in multiple chunks, an error is returned. // // This function is primarily used in streaming scenarios where tool output is delivered // in multiple chunks that need to be merged into a complete result. // // Parameters: // - chunks: A slice of ToolResult pointers representing sequential chunks from a stream. // Nil chunks and chunks with empty Parts are safely ignored. // // Returns: // - *ToolResult: The merged ToolResult containing all content from the chunks. // Returns an empty ToolResult if chunks is empty or all chunks are nil/empty. // - error: An error if the same non-text part type appears in multiple chunks. 
func ConcatToolResults(chunks []*ToolResult) (*ToolResult, error) { if len(chunks) == 0 { return &ToolResult{}, nil } nonTextPartTypes := make(map[ToolPartType]int) var allParts []ToolOutputPart for chunkIdx, chunk := range chunks { if chunk == nil || len(chunk.Parts) == 0 { continue } for _, part := range chunk.Parts { if part.Type != ToolPartTypeText { if prevChunkIdx, exists := nonTextPartTypes[part.Type]; exists { return nil, fmt.Errorf("conflicting %s parts found in chunk %d and chunk %d: "+ "non-text modality parts cannot appear in multiple chunks", part.Type, prevChunkIdx, chunkIdx) } nonTextPartTypes[part.Type] = chunkIdx } } mergedChunkParts, err := concatToolOutputParts(chunk.Parts) if err != nil { return nil, fmt.Errorf("failed to merge text parts in chunk %d: %w", chunkIdx, err) } allParts = append(allParts, mergedChunkParts...) } if len(allParts) == 0 { return &ToolResult{}, nil } return &ToolResult{Parts: allParts}, nil } func concatToolOutputParts(parts []ToolOutputPart) ([]ToolOutputPart, error) { if len(parts) == 0 { return nil, nil } groups := groupToolOutputParts(parts) merged := make([]ToolOutputPart, 0, len(groups)) for _, group := range groups { if len(group) == 1 { merged = append(merged, group...) continue } switch group[0].Type { case ToolPartTypeText: mergedPart, err := mergeToolTextParts(group) if err != nil { return nil, err } merged = append(merged, mergedPart) default: merged = append(merged, group...) 
} } return merged, nil } func groupToolOutputParts(parts []ToolOutputPart) [][]ToolOutputPart { groups := make([][]ToolOutputPart, 0) i := 0 for i < len(parts) { if parts[i].Type == ToolPartTypeText { end := i + 1 for end < len(parts) && parts[end].Type == ToolPartTypeText { end++ } groups = append(groups, parts[i:end]) i = end } else { groups = append(groups, parts[i:i+1]) i++ } } return groups } func mergeToolTextParts(group []ToolOutputPart) (ToolOutputPart, error) { var sb strings.Builder extraList := make([]map[string]any, 0, len(group)) for _, part := range group { sb.WriteString(part.Text) if len(part.Extra) > 0 { extraList = append(extraList, part.Extra) } } var mergedExtra map[string]any if len(extraList) > 0 { var err error mergedExtra, err = concatExtra(extraList) if err != nil { return ToolOutputPart{}, fmt.Errorf("failed to concat tool output text part extra: %w", err) } } return ToolOutputPart{ Type: ToolPartTypeText, Text: sb.String(), Extra: mergedExtra, }, nil } func concatToolCalls(chunks []ToolCall) ([]ToolCall, error) { var merged []ToolCall m := make(map[int][]int) for i := range chunks { index := chunks[i].Index if index == nil { merged = append(merged, chunks[i]) } else { m[*index] = append(m[*index], i) } } var args strings.Builder for k, v := range m { index := k toolCall := ToolCall{Index: &index} if len(v) > 0 { toolCall = chunks[v[0]] } args.Reset() toolID, toolType, toolName := "", "", "" // these field will output atomically in any chunk for _, n := range v { chunk := chunks[n] if chunk.ID != "" { if toolID == "" { toolID = chunk.ID } else if toolID != chunk.ID { return nil, fmt.Errorf("cannot concat ToolCalls with different tool id: '%s' '%s'", toolID, chunk.ID) } } if chunk.Type != "" { if toolType == "" { toolType = chunk.Type } else if toolType != chunk.Type { return nil, fmt.Errorf("cannot concat ToolCalls with different tool type: '%s' '%s'", toolType, chunk.Type) } } if chunk.Function.Name != "" { if toolName == "" { toolName = 
chunk.Function.Name } else if toolName != chunk.Function.Name { return nil, fmt.Errorf("cannot concat ToolCalls with different tool name: '%s' '%s'", toolName, chunk.Function.Name) } } if chunk.Function.Arguments != "" { _, err := args.WriteString(chunk.Function.Arguments) if err != nil { return nil, err } } } toolCall.ID = toolID toolCall.Type = toolType toolCall.Function.Name = toolName toolCall.Function.Arguments = args.String() merged = append(merged, toolCall) } if len(merged) > 1 { sort.SliceStable(merged, func(i, j int) bool { iVal, jVal := merged[i].Index, merged[j].Index if iVal == nil && jVal == nil { return false } else if iVal == nil && jVal != nil { return true } else if iVal != nil && jVal == nil { return false } return *iVal < *jVal }) } return merged, nil } func concatAssistantMultiContent(parts []MessageOutputPart) ([]MessageOutputPart, error) { if len(parts) == 0 { return parts, nil } groups := groupOutputParts(parts) merged := make([]MessageOutputPart, 0, len(groups)) for _, group := range groups { mergedPart, err := mergeOutputPartGroup(group) if err != nil { return nil, err } merged = append(merged, mergedPart) } return merged, nil } func groupOutputParts(parts []MessageOutputPart) [][]MessageOutputPart { if len(parts) == 0 { return nil } groups := make([][]MessageOutputPart, 0) currentGroup := []MessageOutputPart{parts[0]} for i := 1; i < len(parts); i++ { if canMergeOutputParts(currentGroup[0], parts[i]) { currentGroup = append(currentGroup, parts[i]) } else { groups = append(groups, currentGroup) currentGroup = []MessageOutputPart{parts[i]} } } groups = append(groups, currentGroup) return groups } func canMergeOutputParts(current, next MessageOutputPart) bool { if current.Type != next.Type { return false } if !isMergeableOutputPartType(current) { return false } if current.StreamingMeta != nil && next.StreamingMeta != nil { return current.StreamingMeta.Index == next.StreamingMeta.Index } return current.StreamingMeta == nil && 
next.StreamingMeta == nil } func isMergeableOutputPartType(part MessageOutputPart) bool { switch part.Type { case ChatMessagePartTypeText, ChatMessagePartTypeReasoning: return true case ChatMessagePartTypeAudioURL: return isBase64MessageOutputAudioPart(part) default: return false } } func mergeOutputPartGroup(group []MessageOutputPart) (MessageOutputPart, error) { if len(group) == 0 { return MessageOutputPart{}, nil } if len(group) == 1 { return group[0], nil } first := group[0] switch first.Type { case ChatMessagePartTypeText: return mergeTextParts(group) case ChatMessagePartTypeReasoning: return mergeReasoningParts(group) case ChatMessagePartTypeAudioURL: if isBase64MessageOutputAudioPart(first) { return mergeAudioParts(group) } } return first, nil } func mergeTextParts(group []MessageOutputPart) (MessageOutputPart, error) { var sb strings.Builder extraList := make([]map[string]any, 0, len(group)) for _, part := range group { sb.WriteString(part.Text) if len(part.Extra) > 0 { extraList = append(extraList, part.Extra) } } var mergedExtra map[string]any if len(extraList) > 0 { var err error mergedExtra, err = concatExtra(extraList) if err != nil { return MessageOutputPart{}, fmt.Errorf("failed to concat text part extra: %w", err) } } return MessageOutputPart{ Type: ChatMessagePartTypeText, Text: sb.String(), Extra: mergedExtra, StreamingMeta: group[0].StreamingMeta, }, nil } func mergeReasoningParts(group []MessageOutputPart) (MessageOutputPart, error) { var textBuilder strings.Builder var signature string extraList := make([]map[string]any, 0, len(group)) for _, part := range group { if part.Reasoning != nil { textBuilder.WriteString(part.Reasoning.Text) if part.Reasoning.Signature != "" { signature = part.Reasoning.Signature } } if len(part.Extra) > 0 { extraList = append(extraList, part.Extra) } } var mergedExtra map[string]any if len(extraList) > 0 { var err error mergedExtra, err = concatExtra(extraList) if err != nil { return MessageOutputPart{}, 
fmt.Errorf("failed to concat reasoning part extra: %w", err) } } return MessageOutputPart{ Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{ Text: textBuilder.String(), Signature: signature, }, Extra: mergedExtra, StreamingMeta: group[0].StreamingMeta, }, nil } func mergeAudioParts(group []MessageOutputPart) (MessageOutputPart, error) { var b64Builder strings.Builder var mimeType string audioExtraList := make([]map[string]any, 0, len(group)) partExtraList := make([]map[string]any, 0, len(group)) for _, part := range group { audioPart := part.Audio if audioPart.Base64Data != nil { b64Builder.WriteString(*audioPart.Base64Data) } if mimeType == "" { mimeType = audioPart.MIMEType } if len(audioPart.Extra) > 0 { audioExtraList = append(audioExtraList, audioPart.Extra) } if len(part.Extra) > 0 { partExtraList = append(partExtraList, part.Extra) } } var mergedAudioExtra map[string]any var err error if len(audioExtraList) > 0 { mergedAudioExtra, err = concatExtra(audioExtraList) if err != nil { return MessageOutputPart{}, fmt.Errorf("failed to concat audio extra: %w", err) } } var mergedPartExtra map[string]any if len(partExtraList) > 0 { mergedPartExtra, err = concatExtra(partExtraList) if err != nil { return MessageOutputPart{}, fmt.Errorf("failed to concat audio part extra: %w", err) } } mergedB64 := b64Builder.String() return MessageOutputPart{ Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{ MessagePartCommon: MessagePartCommon{ Base64Data: &mergedB64, MIMEType: mimeType, Extra: mergedAudioExtra, }, }, Extra: mergedPartExtra, StreamingMeta: group[0].StreamingMeta, }, nil } func isBase64MessageOutputAudioPart(part MessageOutputPart) bool { return part.Type == ChatMessagePartTypeAudioURL && part.Audio != nil && part.Audio.Base64Data != nil && part.Audio.URL == nil } func concatUserMultiContent(parts []MessageInputPart) ([]MessageInputPart, error) { if len(parts) == 0 { return parts, nil } merged := make([]MessageInputPart, 0, 
len(parts)) i := 0 for i < len(parts) { currentPart := parts[i] if currentPart.Type == ChatMessagePartTypeText { end := i + 1 for end < len(parts) && parts[end].Type == ChatMessagePartTypeText { end++ } if end == i+1 { merged = append(merged, currentPart) } else { var sb strings.Builder for k := i; k < end; k++ { sb.WriteString(parts[k].Text) } mergedPart := MessageInputPart{ Type: ChatMessagePartTypeText, Text: sb.String(), } merged = append(merged, mergedPart) } i = end } else { merged = append(merged, currentPart) i++ } } return merged, nil } func concatExtra(extraList []map[string]any) (map[string]any, error) { if len(extraList) == 1 { return generic.CopyMap(extraList[0]), nil } return internal.ConcatItems(extraList) } // ConcatMessages concat messages with the same role and name. // It will concat tool calls with the same index. // It will return an error if the messages have different roles or names. // It's useful for concatenating messages from a stream. // e.g. // // msgs := []*Message{} // for { // msg, err := stream.Recv() // if errors.Is(err, io.EOF) { // break // } // if err != nil {...} // msgs = append(msgs, msg) // } // // concatedMsg, err := ConcatMessages(msgs) // concatedMsg.Content will be full content of all messages func ConcatMessages(msgs []*Message) (*Message, error) { var ( contents []string contentLen int reasoningContents []string reasoningContentLen int toolCalls []ToolCall multiContentParts []ChatMessagePart assistantGenMultiContentParts []MessageOutputPart userInputMultiContentParts []MessageInputPart ret = Message{} extraList = make([]map[string]any, 0, len(msgs)) ) for idx, msg := range msgs { if msg == nil { return nil, fmt.Errorf("unexpected nil chunk in message stream, index: %d", idx) } if msg.Role != "" { if ret.Role == "" { ret.Role = msg.Role } else if ret.Role != msg.Role { return nil, fmt.Errorf("cannot concat messages with "+ "different roles: '%s' '%s'", ret.Role, msg.Role) } } if msg.Name != "" { if ret.Name == "" { 
ret.Name = msg.Name } else if ret.Name != msg.Name { return nil, fmt.Errorf("cannot concat messages with"+ " different names: '%s' '%s'", ret.Name, msg.Name) } } if msg.ToolCallID != "" { if ret.ToolCallID == "" { ret.ToolCallID = msg.ToolCallID } else if ret.ToolCallID != msg.ToolCallID { return nil, fmt.Errorf("cannot concat messages with"+ " different toolCallIDs: '%s' '%s'", ret.ToolCallID, msg.ToolCallID) } } if msg.ToolName != "" { if ret.ToolName == "" { ret.ToolName = msg.ToolName } else if ret.ToolName != msg.ToolName { return nil, fmt.Errorf("cannot concat messages with"+ " different toolNames: '%s' '%s'", ret.ToolCallID, msg.ToolCallID) } } if msg.Content != "" { contents = append(contents, msg.Content) contentLen += len(msg.Content) } if msg.ReasoningContent != "" { reasoningContents = append(reasoningContents, msg.ReasoningContent) reasoningContentLen += len(msg.ReasoningContent) } if len(msg.ToolCalls) > 0 { toolCalls = append(toolCalls, msg.ToolCalls...) } if len(msg.Extra) > 0 { extraList = append(extraList, msg.Extra) } // The 'MultiContent' field is deprecated but is kept for backward compatibility. if len(msg.MultiContent) > 0 { multiContentParts = append(multiContentParts, msg.MultiContent...) } if len(msg.AssistantGenMultiContent) > 0 { assistantGenMultiContentParts = append(assistantGenMultiContentParts, msg.AssistantGenMultiContent...) } if len(msg.UserInputMultiContent) > 0 { userInputMultiContentParts = append(userInputMultiContentParts, msg.UserInputMultiContent...) } if msg.ResponseMeta != nil && ret.ResponseMeta == nil { ret.ResponseMeta = &ResponseMeta{} } if msg.ResponseMeta != nil && ret.ResponseMeta != nil { // keep the last FinishReason with a valid value. 
if msg.ResponseMeta.FinishReason != "" { ret.ResponseMeta.FinishReason = msg.ResponseMeta.FinishReason } if msg.ResponseMeta.Usage != nil { if ret.ResponseMeta.Usage == nil { ret.ResponseMeta.Usage = &TokenUsage{} } if msg.ResponseMeta.Usage.PromptTokens > ret.ResponseMeta.Usage.PromptTokens { ret.ResponseMeta.Usage.PromptTokens = msg.ResponseMeta.Usage.PromptTokens } if msg.ResponseMeta.Usage.CompletionTokens > ret.ResponseMeta.Usage.CompletionTokens { ret.ResponseMeta.Usage.CompletionTokens = msg.ResponseMeta.Usage.CompletionTokens } if msg.ResponseMeta.Usage.TotalTokens > ret.ResponseMeta.Usage.TotalTokens { ret.ResponseMeta.Usage.TotalTokens = msg.ResponseMeta.Usage.TotalTokens } if msg.ResponseMeta.Usage.PromptTokenDetails.CachedTokens > ret.ResponseMeta.Usage.PromptTokenDetails.CachedTokens { ret.ResponseMeta.Usage.PromptTokenDetails.CachedTokens = msg.ResponseMeta.Usage.PromptTokenDetails.CachedTokens } if msg.ResponseMeta.Usage.CompletionTokensDetails.ReasoningTokens > ret.ResponseMeta.Usage.CompletionTokensDetails.ReasoningTokens { ret.ResponseMeta.Usage.CompletionTokensDetails.ReasoningTokens = msg.ResponseMeta.Usage.CompletionTokensDetails.ReasoningTokens } } if msg.ResponseMeta.LogProbs != nil { if ret.ResponseMeta.LogProbs == nil { ret.ResponseMeta.LogProbs = &LogProbs{} } ret.ResponseMeta.LogProbs.Content = append(ret.ResponseMeta.LogProbs.Content, msg.ResponseMeta.LogProbs.Content...) 
} } } if len(contents) > 0 { var sb strings.Builder sb.Grow(contentLen) for _, content := range contents { _, err := sb.WriteString(content) if err != nil { return nil, err } } ret.Content = sb.String() } if len(reasoningContents) > 0 { var sb strings.Builder sb.Grow(reasoningContentLen) for _, rc := range reasoningContents { _, err := sb.WriteString(rc) if err != nil { return nil, err } } ret.ReasoningContent = sb.String() } if len(toolCalls) > 0 { merged, err := concatToolCalls(toolCalls) if err != nil { return nil, err } ret.ToolCalls = merged } if len(extraList) > 0 { extra, err := concatExtra(extraList) if err != nil { return nil, fmt.Errorf("failed to concat message's extra: %w", err) } if len(extra) > 0 { ret.Extra = extra } } if len(multiContentParts) > 0 { ret.MultiContent = multiContentParts } if len(assistantGenMultiContentParts) > 0 { merged, err := concatAssistantMultiContent(assistantGenMultiContentParts) if err != nil { return nil, fmt.Errorf("failed to concat message's assistant multicontent: %w", err) } ret.AssistantGenMultiContent = merged } if len(userInputMultiContentParts) > 0 { merged, err := concatUserMultiContent(userInputMultiContentParts) if err != nil { return nil, fmt.Errorf("failed to concat message's user multicontent: %w", err) } ret.UserInputMultiContent = merged } return &ret, nil } // ConcatMessageStream drains a stream of messages and returns a single // concatenated message representing the merged content. 
func ConcatMessageStream(s *StreamReader[*Message]) (*Message, error) { defer s.Close() var msgs []*Message for { msg, err := s.Recv() if err != nil { if err == io.EOF { break } return nil, err } msgs = append(msgs, msg) } return ConcatMessages(msgs) } // custom jinja env var jinjaEnvOnce sync.Once var jinjaEnv *gonja.Environment var envInitErr error const ( jinjaInclude = "include" jinjaExtends = "extends" jinjaImport = "import" jinjaFrom = "from" ) func getJinjaEnv() (*gonja.Environment, error) { jinjaEnvOnce.Do(func() { jinjaEnv = gonja.NewEnvironment(config.DefaultConfig, gonja.DefaultLoader) formatInitError := "init jinja env fail: %w" var err error if jinjaEnv.Statements.Exists(jinjaInclude) { err = jinjaEnv.Statements.Replace(jinjaInclude, func(parser *parser.Parser, args *parser.Parser) (nodes.Statement, error) { return nil, fmt.Errorf("keyword[include] has been disabled") }) if err != nil { envInitErr = fmt.Errorf(formatInitError, err) return } } if jinjaEnv.Statements.Exists(jinjaExtends) { err = jinjaEnv.Statements.Replace(jinjaExtends, func(parser *parser.Parser, args *parser.Parser) (nodes.Statement, error) { return nil, fmt.Errorf("keyword[extends] has been disabled") }) if err != nil { envInitErr = fmt.Errorf(formatInitError, err) return } } if jinjaEnv.Statements.Exists(jinjaFrom) { err = jinjaEnv.Statements.Replace(jinjaFrom, func(parser *parser.Parser, args *parser.Parser) (nodes.Statement, error) { return nil, fmt.Errorf("keyword[from] has been disabled") }) if err != nil { envInitErr = fmt.Errorf(formatInitError, err) return } } if jinjaEnv.Statements.Exists(jinjaImport) { err = jinjaEnv.Statements.Replace(jinjaImport, func(parser *parser.Parser, args *parser.Parser) (nodes.Statement, error) { return nil, fmt.Errorf("keyword[import] has been disabled") }) if err != nil { envInitErr = fmt.Errorf(formatInitError, err) return } } }) return jinjaEnv, envInitErr } ================================================ FILE: schema/message_parser.go 
================================================
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package schema

import (
	"context"
	"fmt"
	"strings"

	"github.com/bytedance/sonic"
)

// MessageParser parses a Message into a strongly typed value.
type MessageParser[T any] interface {
	// Parse converts m into a value of type T, returning an error when m
	// cannot be parsed.
	Parse(ctx context.Context, m *Message) (T, error)
}

// MessageParseFrom selects which part of a Message holds the data to parse.
// NewMessageJSONParser defaults it to MessageParseFromContent (Message.Content).
type MessageParseFrom string

// Supported MessageParseFrom values.
const (
	// MessageParseFromContent parses Message.Content.
	MessageParseFromContent MessageParseFrom = "content"
	// MessageParseFromToolCall parses the Function.Arguments of the message's
	// first tool call.
	MessageParseFromToolCall MessageParseFrom = "tool_call"
)

// MessageJSONParseConfig configures JSON parsing behavior for Message.
type MessageJSONParseConfig struct {
	// ParseFrom selects the data source: the message content or its first tool
	// call's arguments. An empty value means content.
	ParseFrom MessageParseFrom `json:"parse_from,omitempty"`

	// ParseKeyPath optionally selects a nested value before decoding. It is a
	// dot-separated list of keys, e.g. "field.sub_field", and each segment is
	// looked up in turn. It is not a JSONPath expression. Empty means the whole
	// document is decoded.
	ParseKeyPath string `json:"parse_key_path,omitempty"`
}

// NewMessageJSONParser creates a MessageParser that JSON-decodes data taken
// from a Message into T. A nil config, or an empty ParseFrom, means the
// message content is parsed.
func NewMessageJSONParser[T any](config *MessageJSONParseConfig) MessageParser[T] { if config == nil { config = &MessageJSONParseConfig{} } if config.ParseFrom == "" { config.ParseFrom = MessageParseFromContent } return &MessageJSONParser[T]{ ParseFrom: config.ParseFrom, ParseKeyPath: config.ParseKeyPath, } } // MessageJSONParser is a parser that parses a message into an object T, using json unmarshal. // eg of parse to single struct: // // config := &MessageJSONParseConfig{ // ParseFrom: MessageParseFromToolCall, // } // parser := NewMessageJSONParser[GetUserParam](config) // param, err := parser.Parse(ctx, message) // // eg of parse to slice of struct: // // config := &MessageJSONParseConfig{ // ParseFrom: MessageParseFromToolCall, // } // // parser := NewMessageJSONParser[GetUserParam](config) // param, err := parser.Parse(ctx, message) type MessageJSONParser[T any] struct { ParseFrom MessageParseFrom ParseKeyPath string } // Parse parses a message into an object T. func (p *MessageJSONParser[T]) Parse(ctx context.Context, m *Message) (parsed T, err error) { if p.ParseFrom == MessageParseFromContent { return p.parse(m.Content) } else if p.ParseFrom == MessageParseFromToolCall { if len(m.ToolCalls) == 0 { return parsed, fmt.Errorf("no tool call found") } return p.parse(m.ToolCalls[0].Function.Arguments) } return parsed, fmt.Errorf("invalid parse from type: %s", p.ParseFrom) } // extractData extracts data from a string using the parse key path. func (p *MessageJSONParser[T]) extractData(data string) (string, error) { if p.ParseKeyPath == "" { return data, nil } keys := strings.Split(p.ParseKeyPath, ".") interfaceKeys := make([]any, len(keys)) for i, key := range keys { interfaceKeys[i] = key } node, err := sonic.GetFromString(data, interfaceKeys...) 
if err != nil { return "", fmt.Errorf("failed to get parse key path: %w", err) } bytes, err := node.MarshalJSON() if err != nil { return "", fmt.Errorf("failed to marshal node: %w", err) } return string(bytes), nil } // parse parses a string into an object T. func (p *MessageJSONParser[T]) parse(data string) (parsed T, err error) { parsedData, err := p.extractData(data) if err != nil { return parsed, err } if err := sonic.UnmarshalString(parsedData, &parsed); err != nil { return parsed, fmt.Errorf("failed to unmarshal content: %w", err) } return parsed, nil } ================================================ FILE: schema/message_parser_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package schema import ( "context" "testing" "github.com/stretchr/testify/assert" ) type TestStructForParse struct { ID int `json:"id"` Name string `json:"name"` XX struct { YY int `json:"yy"` } `json:"xx"` } func TestMessageJSONParser(t *testing.T) { ctx := context.Background() t.Run("parse from content", func(t *testing.T) { parser := NewMessageJSONParser[TestStructForParse](&MessageJSONParseConfig{ ParseFrom: MessageParseFromContent, }) parsed, err := parser.Parse(ctx, &Message{ Content: `{"id": 1, "name": "test", "xx": {"yy": 2}}`, }) assert.Nil(t, err) assert.Equal(t, 1, parsed.ID) }) t.Run("parse from tool call", func(t *testing.T) { t.Run("only one tool call, default use first tool call", func(t *testing.T) { parser := NewMessageJSONParser[TestStructForParse](&MessageJSONParseConfig{ ParseFrom: MessageParseFromToolCall, }) parsed, err := parser.Parse(ctx, &Message{ ToolCalls: []ToolCall{ {Function: FunctionCall{Arguments: `{"id": 1, "name": "test", "xx": {"yy": 2}}`}}, }, }) assert.Nil(t, err) assert.Equal(t, 1, parsed.ID) }) t.Run("parse key path", func(t *testing.T) { type TestStructForParse2 struct { YY int `json:"yy"` } parser := NewMessageJSONParser[TestStructForParse2](&MessageJSONParseConfig{ ParseFrom: MessageParseFromToolCall, ParseKeyPath: "xx", }) parsed, err := parser.Parse(ctx, &Message{ ToolCalls: []ToolCall{ {Function: FunctionCall{Arguments: `{"id": 1, "name": "test", "xx": {"yy": 2}}`}}, }, }) assert.Nil(t, err) assert.Equal(t, 2, parsed.YY) }) t.Run("parse key path, deep level", func(t *testing.T) { type TestStructForParse3 struct { ZZ int `json:"zz"` } parser := NewMessageJSONParser[TestStructForParse3](&MessageJSONParseConfig{ ParseFrom: MessageParseFromToolCall, ParseKeyPath: "xx.yy", }) parsed, err := parser.Parse(ctx, &Message{ ToolCalls: []ToolCall{ {Function: FunctionCall{Arguments: `{"id": 1, "name": "test", "xx": {"yy": {"zz": 3}}}`}}, }, }) assert.Nil(t, err) assert.Equal(t, 3, parsed.ZZ) }) t.Run("parse key with pointer", 
func(t *testing.T) { type TestStructForParse4 struct { ZZ *int `json:"zz"` } parser := NewMessageJSONParser[**TestStructForParse4](&MessageJSONParseConfig{ ParseFrom: MessageParseFromToolCall, }) parsed, err := parser.Parse(ctx, &Message{ ToolCalls: []ToolCall{{Function: FunctionCall{Arguments: `{"zz": 3}`}}}, }) assert.Nil(t, err) assert.Equal(t, 3, *((**parsed).ZZ)) }) }) t.Run("parse of slice", func(t *testing.T) { t.Run("valid slice string, not multiple tool calls", func(t *testing.T) { parser := NewMessageJSONParser[[]map[string]any](&MessageJSONParseConfig{ ParseFrom: MessageParseFromToolCall, }) parsed, err := parser.Parse(ctx, &Message{ ToolCalls: []ToolCall{{Function: FunctionCall{Arguments: `[{"id": 1}, {"id": 2}]`}}}, }) assert.Nil(t, err) assert.Equal(t, 2, len(parsed)) }) t.Run("invalid slice string, not multiple tool calls", func(t *testing.T) { parser := NewMessageJSONParser[[]map[string]any](&MessageJSONParseConfig{ ParseFrom: MessageParseFromToolCall, }) _, err := parser.Parse(ctx, &Message{ ToolCalls: []ToolCall{ {Function: FunctionCall{Arguments: `{"id": 1}`}}, {Function: FunctionCall{Arguments: `{"id": 2}`}}, }, }) assert.NotNil(t, err) }) }) t.Run("invalid configs", func(t *testing.T) { parser := NewMessageJSONParser[TestStructForParse](nil) _, err := parser.Parse(ctx, &Message{ Content: "", }) assert.NotNil(t, err) }) t.Run("invalid parse key path", func(t *testing.T) { parser := NewMessageJSONParser[TestStructForParse](&MessageJSONParseConfig{ ParseKeyPath: "...invalid", }) _, err := parser.Parse(ctx, &Message{}) assert.NotNil(t, err) }) t.Run("invalid parse from", func(t *testing.T) { parser := NewMessageJSONParser[TestStructForParse](&MessageJSONParseConfig{ ParseFrom: "invalid", }) _, err := parser.Parse(ctx, &Message{}) assert.NotNil(t, err) }) t.Run("invalid parse from type", func(t *testing.T) { parser := NewMessageJSONParser[int](&MessageJSONParseConfig{ ParseFrom: MessageParseFrom("invalid"), }) _, err := parser.Parse(ctx, &Message{}) 
assert.NotNil(t, err) }) } ================================================ FILE: schema/message_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package schema import ( "context" "reflect" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/internal/generic" ) func TestMessageTemplate(t *testing.T) { pyFmtMessage := UserMessage("input: {question}") jinja2Message := UserMessage("input: {{question}}") goTemplateMessage := UserMessage("input: {{.question}}") ctx := context.Background() question := "what's the weather today" expected := []*Message{UserMessage("input: " + question)} ms, err := pyFmtMessage.Format(ctx, map[string]any{"question": question}, FString) assert.Nil(t, err) assert.True(t, reflect.DeepEqual(expected, ms)) ms, err = jinja2Message.Format(ctx, map[string]any{"question": question}, Jinja2) assert.Nil(t, err) assert.True(t, reflect.DeepEqual(expected, ms)) ms, err = goTemplateMessage.Format(ctx, map[string]any{"question": question}, GoTemplate) assert.Nil(t, err) assert.True(t, reflect.DeepEqual(expected, ms)) mp := MessagesPlaceholder("chat_history", false) m1 := UserMessage("how are you?") m2 := AssistantMessage("I'm good. 
how about you?", nil)
	ms, err = mp.Format(ctx, map[string]any{"chat_history": []*Message{m1, m2}}, FString)
	assert.Nil(t, err)
	// len(ms) == 2
	assert.Equal(t, 2, len(ms))
	assert.Equal(t, ms[0], m1)
	assert.Equal(t, ms[1], m2)
}

// TestConcatMessage exercises ConcatMessages, which merges a sequence of
// streamed message chunks into a single Message. It covers tool-call
// fragments, ResponseMeta merging (finish reason, usage, logprobs), assistant
// multi-content parts and the error paths for mismatched roles, names and
// tool call IDs.
func TestConcatMessage(t *testing.T) {
	t.Run("tool_call_normal_append", func(t *testing.T) {
		expectMsg := &Message{
			Role:    "assistant",
			Content: "",
			ToolCalls: []ToolCall{
				{
					Index: generic.PtrOf(0),
					ID:    "i_am_a_tool_call_id",
					Type:  "function",
					Function: FunctionCall{
						Name:      "i_am_a_tool_name",
						Arguments: "{}",
					},
				},
			},
		}

		// The role, tool call ID, type, name and arguments each arrive in a
		// different chunk; all three chunks share tool-call Index 0.
		givenMsgList := []*Message{
			{
				Role:    "",
				Content: "",
				ToolCalls: []ToolCall{
					{
						Index: generic.PtrOf(0),
						ID:    "",
						Type:  "",
						Function: FunctionCall{
							Name: "",
						},
					},
				},
			},
			{
				Role:    "assistant",
				Content: "",
				ToolCalls: []ToolCall{
					{
						Index: generic.PtrOf(0),
						ID:    "i_am_a_tool_call_id",
						Type:  "function",
						Function: FunctionCall{
							Name: "i_am_a_tool_name",
						},
					},
				},
			},
			{
				Role:    "",
				Content: "",
				ToolCalls: []ToolCall{
					{
						Index: generic.PtrOf(0),
						ID:    "",
						Type:  "",
						Function: FunctionCall{
							Name:      "",
							Arguments: "{}",
						},
					},
				},
			},
		}

		msg, err := ConcatMessages(givenMsgList)
		assert.NoError(t, err)
		assert.EqualValues(t, expectMsg, msg)
	})

	t.Run("exist_nil_message", func(t *testing.T) {
		givenMsgList := []*Message{
			nil,
			{
				Role:    "assistant",
				Content: "",
				ToolCalls: []ToolCall{
					{
						Index: generic.PtrOf(0),
						ID:    "i_am_a_tool_call_id",
						Type:  "function",
						Function: FunctionCall{
							Name: "i_am_a_tool_name",
						},
					},
				},
			},
		}

		_, err := ConcatMessages(givenMsgList)
		assert.ErrorContains(t, err, "unexpected nil chunk in message stream")
	})

	t.Run("response_meta", func(t *testing.T) {
		// The merged Usage equals the last chunk that carries one (15/30/45),
		// not the sum of all usages; the finish reason is the last non-empty one.
		expectedMsg := &Message{
			Role: "assistant",
			ResponseMeta: &ResponseMeta{
				FinishReason: "stop",
				Usage: &TokenUsage{
					CompletionTokens: 15,
					PromptTokens:     30,
					PromptTokenDetails: PromptTokenDetails{
						CachedTokens: 15,
					},
					CompletionTokensDetails: CompletionTokensDetails{
						ReasoningTokens: 8,
					},
					TotalTokens: 45,
				},
			},
		}

		givenMsgList := []*Message{
			{
				Role: "assistant",
			},
			{
				Role: "assistant",
				ResponseMeta: &ResponseMeta{
					FinishReason: "",
					Usage: &TokenUsage{
						CompletionTokens: 10,
						PromptTokens:     20,
						PromptTokenDetails: PromptTokenDetails{
							CachedTokens: 10,
						},
						CompletionTokensDetails: CompletionTokensDetails{
							ReasoningTokens: 5,
						},
						TotalTokens: 30,
					},
				},
			},
			{
				Role: "assistant",
				ResponseMeta: &ResponseMeta{
					FinishReason: "stop",
				},
			},
			{
				Role: "assistant",
				ResponseMeta: &ResponseMeta{
					Usage: &TokenUsage{
						CompletionTokens: 15,
						PromptTokens:     30,
						PromptTokenDetails: PromptTokenDetails{
							CachedTokens: 15,
						},
						CompletionTokensDetails: CompletionTokensDetails{
							ReasoningTokens: 8,
						},
						TotalTokens: 45,
					},
				},
			},
		}

		msg, err := ConcatMessages(givenMsgList)
		assert.NoError(t, err)
		assert.Equal(t, expectedMsg, msg)

		// A later non-empty finish reason overrides the earlier one.
		givenMsgList = append(givenMsgList, &Message{
			Role: "assistant",
			ResponseMeta: &ResponseMeta{
				FinishReason: "tool_calls",
			},
		})
		msg, err = ConcatMessages(givenMsgList)
		assert.NoError(t, err)
		expectedMsg.ResponseMeta.FinishReason = "tool_calls"
		assert.Equal(t, expectedMsg, msg)
	})

	t.Run("err: different roles", func(t *testing.T) {
		msgs := []*Message{
			{Role: User},
			{Role: Assistant},
		}

		msg, err := ConcatMessages(msgs)
		if assert.Error(t, err) {
			assert.ErrorContains(t, err, "cannot concat messages with different roles")
			assert.Nil(t, msg)
		}
	})

	t.Run("err: different name", func(t *testing.T) {
		msgs := []*Message{
			{Role: Assistant, Name: "n", Content: "1"},
			{Role: Assistant, Name: "a", Content: "2"},
		}

		msg, err := ConcatMessages(msgs)
		if assert.Error(t, err) {
			assert.ErrorContains(t, err, "cannot concat messages with different names")
			assert.Nil(t, msg)
		}
	})

	// The two chunks differ only in the message-level ToolCallID ("123" vs
	// "321"); their tool calls are compatible, so the error must come from the
	// ToolCallID mismatch.
	t.Run("err: different tool call ids", func(t *testing.T) {
		msgs := []*Message{
			{
				Role:       "",
				Content:    "",
				ToolCallID: "123",
				ToolCalls: []ToolCall{
					{
						Index: generic.PtrOf(0),
						ID:    "abc",
						Type:  "",
						Function: FunctionCall{
							Name: "",
						},
					},
				},
			},
			{
				Role:       "assistant",
				Content:    "",
				ToolCallID: "321",
				ToolCalls: []ToolCall{
					{
						Index: generic.PtrOf(0),
						ID:    "abc",
						Type:  "function",
						Function: FunctionCall{
							Name: "i_am_a_tool_name",
						},
					},
				},
			},
		}

		msg, err := ConcatMessages(msgs)
		if assert.Error(t, err) {
			assert.ErrorContains(t, err, "cannot concat messages with different toolCallIDs")
			assert.Nil(t, msg)
		}
	})

	t.Run("first response meta usage is nil", func(t *testing.T) {
		exp := &Message{
			Role: "assistant",
			ResponseMeta: &ResponseMeta{
				FinishReason: "stop",
				Usage: &TokenUsage{
					CompletionTokens: 15,
					PromptTokens:     30,
					TotalTokens:      45,
				},
			},
		}

		msgs := []*Message{
			{
				Role: "assistant",
				ResponseMeta: &ResponseMeta{
					FinishReason: "",
					Usage:        nil,
				},
			},
			{
				Role: "assistant",
				ResponseMeta: &ResponseMeta{
					FinishReason: "stop",
				},
			},
			{
				Role: "assistant",
				ResponseMeta: &ResponseMeta{
					Usage: &TokenUsage{
						CompletionTokens: 15,
						PromptTokens:     30,
						TotalTokens:      45,
					},
				},
			},
		}

		msg, err := ConcatMessages(msgs)
		assert.NoError(t, err)
		assert.Equal(t, exp, msg)
	})

	// Calls ConcatMessages on the same input slice from many goroutines. This is
	// presumably meant to surface shared-state mutation when run with -race.
	t.Run("concurrent concat", func(t *testing.T) {
		content := "i_am_a_good_concat_message"
		exp := &Message{Role: Assistant, Content: content}

		// One single-byte chunk per character of content.
		msgs := make([]*Message, 0, len(content))
		for i := 0; i < len(content); i++ {
			msgs = append(msgs, &Message{Role: Assistant, Content: content[i : i+1]})
		}

		const size = 100
		var wg sync.WaitGroup
		wg.Add(size)
		for i := 0; i < size; i++ {
			go func() {
				defer wg.Done()
				msg, err := ConcatMessages(msgs)
				assert.NoError(t, err)
				assert.Equal(t, exp, msg)
			}()
		}

		wg.Wait()
	})

	// Logprob entries from every chunk are appended in chunk order.
	t.Run("concat logprobs", func(t *testing.T) {
		msgs := []*Message{
			{
				Role:    Assistant,
				Content: "🚀",
				ResponseMeta: &ResponseMeta{
					LogProbs: &LogProbs{
						Content: []LogProb{
							{
								Token:   "\\xf0\\x9f\\x9a",
								LogProb: -0.0000073458323,
								Bytes:   []int64{240, 159, 154},
							},
							{
								Token:   "\\x80",
								LogProb: 0,
								Bytes:   []int64{128},
							},
						},
					},
				},
			},
			{
				Role:    "",
				Content: "❤️",
				ResponseMeta: &ResponseMeta{
					LogProbs: &LogProbs{
						Content: []LogProb{
							{
								Token:   "❤️",
								LogProb: -0.0011431955,
								Bytes:   []int64{226, 157, 164, 239, 184, 143},
							},
						},
					},
				},
			},
			{
				Role: "",
				ResponseMeta: &ResponseMeta{
					FinishReason: "stop",
					Usage: &TokenUsage{
						PromptTokens:     7,
						CompletionTokens: 3,
						TotalTokens:      10,
					},
				},
			},
		}

		msg, err := ConcatMessages(msgs)
		assert.NoError(t, err)
		// Guard the dereferences and indexing below so a regression fails the
		// subtest instead of panicking and aborting the whole test binary.
		if assert.NotNil(t, msg) &&
			assert.NotNil(t, msg.ResponseMeta) &&
			assert.NotNil(t, msg.ResponseMeta.LogProbs) &&
			assert.Len(t, msg.ResponseMeta.LogProbs.Content, 3) {
			assert.Equal(t, msgs[0].ResponseMeta.LogProbs.Content[0], msg.ResponseMeta.LogProbs.Content[0])
			assert.Equal(t, msgs[0].ResponseMeta.LogProbs.Content[1], msg.ResponseMeta.LogProbs.Content[1])
			assert.Equal(t, msgs[1].ResponseMeta.LogProbs.Content[0], msg.ResponseMeta.LogProbs.Content[2])
		}
	})

	t.Run("fix unexpected setting ResponseMeta of the first element in slice after ConcatMessages", func(t *testing.T) {
		msgs := []*Message{
			{
				Role:    Assistant,
				Content: "🚀",
				// ResponseMeta is intentionally left nil: it must still be nil
				// after ConcatMessages, i.e. the merged meta must not be written
				// back into the caller's first chunk.
			},
			{
				Role:         "",
				Content:      "❤️",
				ResponseMeta: &ResponseMeta{},
			},
			{
				Role: "",
				ResponseMeta: &ResponseMeta{
					FinishReason: "stop",
					Usage: &TokenUsage{
						PromptTokens:     7,
						CompletionTokens: 3,
						TotalTokens:      10,
					},
				},
			},
		}

		msg, err := ConcatMessages(msgs)
		assert.NoError(t, err)
		assert.Equal(t, msgs[2].ResponseMeta, msg.ResponseMeta)
		assert.Nil(t, msgs[0].ResponseMeta)
	})

	// Adjacent text parts are joined and adjacent base64 audio parts have their
	// data concatenated (a later MIMEType is kept). Image parts are never merged.
	t.Run("concat assistant multi content", func(t *testing.T) {
		base64Audio1 := "dGVzdF9hdWRpb18x"
		base64Audio2 := "dGVzdF9hdWRpb18y"
		imageURL1 := "https://example.com/image1.png"
		imageURL2 := "https://example.com/image2.png"

		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Hello, "},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "world!"},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &base64Audio1}}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &base64Audio2, MIMEType: "audio/wav"}}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeImageURL, Image: &MessageOutputImage{MessagePartCommon: MessagePartCommon{URL: &imageURL1}}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeImageURL, Image: &MessageOutputImage{MessagePartCommon: MessagePartCommon{URL: &imageURL2}}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		mergedBase64Audio := base64Audio1 + base64Audio2
		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeText, Text: "Hello, world!"},
			{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &mergedBase64Audio, MIMEType: "audio/wav"}}},
			{Type: ChatMessagePartTypeImageURL, Image: &MessageOutputImage{MessagePartCommon: MessagePartCommon{URL: &imageURL1}}},
			{Type: ChatMessagePartTypeImageURL, Image: &MessageOutputImage{MessagePartCommon: MessagePartCommon{URL: &imageURL2}}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	// Extra maps on the inner audio payload are unioned when audio merges.
	t.Run("concat assistant multi content with extra", func(t *testing.T) {
		base64Audio1 := "dGVzdF9hdWRpb18x"
		base64Audio2 := "dGVzdF9hdWRpb18y"

		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &base64Audio1, Extra: map[string]any{"key1": "val1"}}}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &base64Audio2, Extra: map[string]any{"key2": "val2"}}}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		mergedBase64Audio := base64Audio1 + base64Audio2
		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &mergedBase64Audio, Extra: map[string]any{"key1": "val1", "key2": "val2"}}}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	t.Run("concat assistant multi content with single extra", func(t *testing.T) {
		base64Audio1 := "dGVzdF9hdWRpb18x"
		base64Audio2 := "dGVzdF9hdWRpb18y"

		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &base64Audio1, Extra: map[string]any{"key1": "val1"}}}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &base64Audio2}}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		mergedBase64Audio := base64Audio1 + base64Audio2
		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &mergedBase64Audio, Extra: map[string]any{"key1": "val1"}}}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	// Part-level Extra maps are unioned when text parts merge.
	t.Run("concat text parts with extra", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Hello ", Extra: map[string]any{"key1": "val1"}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "World", Extra: map[string]any{"key2": "val2"}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeText, Text: "Hello World", Extra: map[string]any{"key1": "val1", "key2": "val2"}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	t.Run("concat text parts with single extra", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Hello ", Extra: map[string]any{"key1": "val1"}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "World"},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeText, Text: "Hello World", Extra: map[string]any{"key1": "val1"}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	t.Run("concat reasoning parts with extra", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "First, "}, Extra: map[string]any{"key1": "val1"}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "I need to think."}, Extra: map[string]any{"key2": "val2"}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "First, I need to think."}, Extra: map[string]any{"key1": "val1", "key2": "val2"}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	// Part-level ("outer") Extra maps are unioned when audio parts merge,
	// independently of the inner audio payload's Extra.
	t.Run("concat audio parts with outer extra", func(t *testing.T) {
		base64Audio1 := "dGVzdF9hdWRpb18x"
		base64Audio2 := "dGVzdF9hdWRpb18y"

		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &base64Audio1}}, Extra: map[string]any{"outer1": "val1"}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &base64Audio2}}, Extra: map[string]any{"outer2": "val2"}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		mergedBase64Audio := base64Audio1 + base64Audio2
		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{MessagePartCommon: MessagePartCommon{Base64Data: &mergedBase64Audio}}, Extra: map[string]any{"outer1": "val1", "outer2": "val2"}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	// Consecutive reasoning parts merge with each other but not with a
	// following text part.
	t.Run("concat reasoning parts", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "First, "}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "I need to think."}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Here is my answer."},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "First, I need to think."}},
			{Type: ChatMessagePartTypeText, Text: "Here is my answer."},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	// The merged reasoning keeps the last non-empty signature ("sig_xyz").
	t.Run("concat reasoning parts with signature", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "Step 1: "}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "analyze.", Signature: "sig_abc"}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: " Step 2: ", Signature: "sig_xyz"}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "conclude."}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "Step 1: analyze. Step 2: conclude.", Signature: "sig_xyz"}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	t.Run("concat with streaming meta index grouping", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "Think "}, StreamingMeta: &MessageStreamingMeta{Index: 0}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "more."}, StreamingMeta: &MessageStreamingMeta{Index: 0}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Hello ", StreamingMeta: &MessageStreamingMeta{Index: 1}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "world!", StreamingMeta: &MessageStreamingMeta{Index: 1}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeReasoning, Reasoning: &MessageOutputReasoning{Text: "Think more."}, StreamingMeta: &MessageStreamingMeta{Index: 0}},
			{Type: ChatMessagePartTypeText, Text: "Hello world!", StreamingMeta: &MessageStreamingMeta{Index: 1}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	// Only adjacent parts are merged: the second Index-0 part is separated from
	// the first by an Index-1 part, so all three stay distinct.
	t.Run("concat with different streaming meta index should not merge", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "First block ", StreamingMeta: &MessageStreamingMeta{Index: 0}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Second block ", StreamingMeta: &MessageStreamingMeta{Index: 1}},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "continues.", StreamingMeta: &MessageStreamingMeta{Index: 0}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeText, Text: "First block ", StreamingMeta: &MessageStreamingMeta{Index: 0}},
			{Type: ChatMessagePartTypeText, Text: "Second block ", StreamingMeta: &MessageStreamingMeta{Index: 1}},
			{Type: ChatMessagePartTypeText, Text: "continues.", StreamingMeta: &MessageStreamingMeta{Index: 0}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	t.Run("concat empty parts", func(t *testing.T) {
		msgs := []*Message{
			{
				Role:                     Assistant,
				AssistantGenMultiContent: []MessageOutputPart{},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)
		assert.Empty(t, mergedMsg.AssistantGenMultiContent)
	})

	t.Run("concat single part no merge needed", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Single"},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeText, Text: "Single"},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	t.Run("concat multiple consecutive text parts", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "One "},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Two "},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Three "},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "Four"},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeText, Text: "One Two Three Four"},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	// A part without StreamingMeta is not merged with one that has it.
	t.Run("concat without streaming meta should not merge with streaming meta parts", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "No meta "},
				},
			},
			{
				Role: Assistant,
				AssistantGenMultiContent: []MessageOutputPart{
					{Type: ChatMessagePartTypeText, Text: "With meta", StreamingMeta: &MessageStreamingMeta{Index: 0}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedContent := []MessageOutputPart{
			{Type: ChatMessagePartTypeText, Text: "No meta "},
			{Type: ChatMessagePartTypeText, Text: "With meta", StreamingMeta: &MessageStreamingMeta{Index: 0}},
		}
		assert.Equal(t, expectedContent, mergedMsg.AssistantGenMultiContent)
	})

	// The deprecated MultiContent field is appended chunk by chunk.
	t.Run("concat multi content (deprecated)", func(t *testing.T) {
		msgs := []*Message{
			{
				Role: Assistant,
				MultiContent: []ChatMessagePart{
					{Type: ChatMessagePartTypeImageURL, ImageURL: &ChatMessageImageURL{URL: "image1.jpg"}},
				},
			},
			{
				Role: Assistant,
				MultiContent: []ChatMessagePart{
					{Type: ChatMessagePartTypeImageURL, ImageURL: &ChatMessageImageURL{URL: "image2.jpg"}},
				},
			},
		}

		mergedMsg, err := ConcatMessages(msgs)
		assert.NoError(t, err)

		expectedMultiContent := []ChatMessagePart{
			{Type: ChatMessagePartTypeImageURL, ImageURL: &ChatMessageImageURL{URL: "image1.jpg"}},
			{Type: ChatMessagePartTypeImageURL, ImageURL: &ChatMessageImageURL{URL: "image2.jpg"}},
		}
		assert.Equal(t, expectedMultiContent, mergedMsg.MultiContent)
	})
}

// TestConcatToolCalls exercises concatToolCalls.
//   - Chunks that share an Index are merged into one call; non-empty ID, Type
//     and Arguments are carried over, and identical repeated values are not
//     duplicated.
//   - Conflicting non-empty ID, Type or Name values are rejected.
//   - Calls with a nil Index are kept as-is and appear ahead of the indexed ones.
func TestConcatToolCalls(t *testing.T) {
	t.Run("atomic_field_in_first_chunk", func(t *testing.T) {
		givenToolCalls := []ToolCall{
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name: "tool_name",
				},
			},
			{
				Index: generic.PtrOf(0),
				Function: FunctionCall{
					Arguments: "call me please",
				},
			},
		}

		expectedToolCall := ToolCall{
			Index: generic.PtrOf(0),
			ID:    "tool_call_id",
			Type:  "function",
			Function: FunctionCall{
				Name:      "tool_name",
				Arguments: "call me please",
			},
		}

		tc, err := concatToolCalls(givenToolCalls)
		assert.NoError(t, err)
		if assert.Len(t, tc, 1) {
			assert.EqualValues(t, expectedToolCall, tc[0])
		}
	})

	t.Run("atomic_field_in_every_chunk", func(t *testing.T) {
		givenToolCalls := []ToolCall{
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name: "tool_name",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name:      "tool_name",
					Arguments: "call me please",
				},
			},
		}

		expectedToolCall := ToolCall{
			Index: generic.PtrOf(0),
			ID:    "tool_call_id",
			Type:  "function",
			Function: FunctionCall{
				Name:      "tool_name",
				Arguments: "call me please",
			},
		}

		tc, err := concatToolCalls(givenToolCalls)
		assert.NoError(t, err)
		if assert.Len(t, tc, 1) {
			assert.EqualValues(t, expectedToolCall, tc[0])
		}
	})

	t.Run("atomic_field_in_interval", func(t *testing.T) {
		givenToolCalls := []ToolCall{
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "",
				Function: FunctionCall{
					Name: "",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "",
				Type:  "function",
				Function: FunctionCall{
					Name:      "",
					Arguments: "call me please",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "",
				Function: FunctionCall{
					Name:      "",
					Arguments: "",
				},
			},
		}

		expectedToolCall := ToolCall{
			Index: generic.PtrOf(0),
			ID:    "tool_call_id",
			Type:  "function",
			Function: FunctionCall{
				Name:      "",
				Arguments: "call me please",
			},
		}

		tc, err := concatToolCalls(givenToolCalls)
		assert.NoError(t, err)
		if assert.Len(t, tc, 1) {
			assert.EqualValues(t, expectedToolCall, tc[0])
		}
	})

	t.Run("different_tool_id", func(t *testing.T) {
		givenToolCalls := []ToolCall{
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name: "tool_name",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id_1",
				Type:  "function",
				Function: FunctionCall{
					Name:      "tool_name",
					Arguments: "call me please",
				},
			},
		}

		_, err := concatToolCalls(givenToolCalls)
		assert.ErrorContains(t, err, "cannot concat ToolCalls with different tool id")
	})

	t.Run("different_tool_type", func(t *testing.T) {
		givenToolCalls := []ToolCall{
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name: "tool_name",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function_1",
				Function: FunctionCall{
					Name:      "tool_name",
					Arguments: "call me please",
				},
			},
		}

		_, err := concatToolCalls(givenToolCalls)
		assert.ErrorContains(t, err, "cannot concat ToolCalls with different tool type")
	})

	t.Run("different_tool_name", func(t *testing.T) {
		givenToolCalls := []ToolCall{
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name: "tool_name",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name:      "tool_name_1",
					Arguments: "call me please",
				},
			},
		}

		_, err := concatToolCalls(givenToolCalls)
		assert.ErrorContains(t, err, "cannot concat ToolCalls with different tool name")
	})

	// Two indexed groups (Index 0 and 1) are each merged into one call; the two
	// nil-Index calls are not merged and come first in the result.
	t.Run("multi_tool_call", func(t *testing.T) {
		givenToolCalls := []ToolCall{
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "",
				Function: FunctionCall{
					Name: "",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "",
				Type:  "function",
				Function: FunctionCall{
					Name:      "",
					Arguments: "call me please",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "",
				Function: FunctionCall{
					Name:      "",
					Arguments: "",
				},
			},
			{
				Index: generic.PtrOf(1),
				ID:    "tool_call_id",
				Type:  "",
				Function: FunctionCall{
					Name: "",
				},
			},
			{
				Index: generic.PtrOf(1),
				ID:    "",
				Type:  "function",
				Function: FunctionCall{
					Name:      "",
					Arguments: "call me please",
				},
			},
			{
				Index: generic.PtrOf(1),
				ID:    "tool_call_id",
				Type:  "",
				Function: FunctionCall{
					Name:      "",
					Arguments: "",
				},
			},
			{
				Index: nil,
				ID:    "22",
				Type:  "",
				Function: FunctionCall{
					Name: "",
				},
			},
			{
				Index: nil,
				ID:    "44",
				Type:  "",
				Function: FunctionCall{
					Name: "",
				},
			},
		}

		expectedToolCall := []ToolCall{
			{
				Index: nil,
				ID:    "22",
				Type:  "",
				Function: FunctionCall{
					Name: "",
				},
			},
			{
				Index: nil,
				ID:    "44",
				Type:  "",
				Function: FunctionCall{
					Name: "",
				},
			},
			{
				Index: generic.PtrOf(0),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name:      "",
					Arguments: "call me please",
				},
			},
			{
				Index: generic.PtrOf(1),
				ID:    "tool_call_id",
				Type:  "function",
				Function: FunctionCall{
					Name:      "",
					Arguments: "call me please",
				},
			},
		}

		tc, err := concatToolCalls(givenToolCalls)
		assert.NoError(t, err)
		assert.EqualValues(t, expectedToolCall, tc)
	})
}

// TestFormatMultiContent checks formatMultiContent in four cases:
//   - nil input yields an empty, non-nil slice;
//   - template variables are rendered in Text and in every media URL;
//   - parts whose media pointer is nil are returned unchanged;
//   - template errors are propagated.
func TestFormatMultiContent(t *testing.T) {
	vs := map[string]any{
		"name": "eino",
		"url":  "https://example.com/img.png",
		"id":   "42",
	}

	t.Run("empty input", func(t *testing.T) {
		out, err := formatMultiContent(nil, vs, FString)
		assert.NoError(t, err)
		assert.Equal(t, []ChatMessagePart{}, out)
	})

	t.Run("render text and urls with FString", func(t *testing.T) {
		in := []ChatMessagePart{
			{Type: ChatMessagePartTypeText, Text: "hello {name}"},
			{Type: ChatMessagePartTypeImageURL, ImageURL: &ChatMessageImageURL{URL: "{url}"}},
			{Type: ChatMessagePartTypeAudioURL, AudioURL: &ChatMessageAudioURL{URL: "http://audio/{id}.wav"}},
			{Type: ChatMessagePartTypeVideoURL, VideoURL: &ChatMessageVideoURL{URL: "http://video/{id}.mp4"}},
			{Type: ChatMessagePartTypeFileURL, FileURL: &ChatMessageFileURL{URL: "http://file/{id}.txt"}},
		}

		out, err := formatMultiContent(in, vs, FString)
		assert.NoError(t, err)
		if assert.Len(t, out, len(in)) {
			assert.Equal(t, "hello eino", out[0].Text)
			assert.Equal(t, "https://example.com/img.png", out[1].ImageURL.URL)
			assert.Equal(t, "http://audio/42.wav", out[2].AudioURL.URL)
			assert.Equal(t, "http://video/42.mp4", out[3].VideoURL.URL)
			assert.Equal(t, "http://file/42.txt", out[4].FileURL.URL)
		}
	})

	t.Run("nil nested pointer should be skipped", func(t *testing.T) {
		in := []ChatMessagePart{
			{Type: ChatMessagePartTypeImageURL, ImageURL: nil},
			{Type: ChatMessagePartTypeAudioURL, AudioURL: nil},
			{Type: ChatMessagePartTypeVideoURL, VideoURL: nil},
			{Type: ChatMessagePartTypeFileURL, FileURL: nil},
		}

		out, err := formatMultiContent(in, vs, FString)
		assert.NoError(t, err)
		assert.Equal(t, in, out)
	})

	t.Run("missing var should error in GoTemplate", func(t *testing.T) {
		// "who" is referenced by the template but absent from the variables.
		in := []ChatMessagePart{{Type: ChatMessagePartTypeText, Text: "hi {{.who}}"}}
		_, err := formatMultiContent(in, map[string]any{"name": "x"}, GoTemplate)
		assert.Error(t, err)
	})
}

// TestFormatUserInputMultiContent checks formatUserInputMultiContent in three
// cases:
//   - nil input yields an empty, non-nil slice;
//   - both URL and Base64Data are rendered for every media part type;
//   - pointers to empty strings are kept non-nil and empty.
func TestFormatUserInputMultiContent(t *testing.T) {
	vs := map[string]any{
		"x":    "world",
		"img":  "https://example.com/i.png",
		"b64":  "YmFzZTY0",
		"aid":  "99",
		"vid":  "77",
		"file": "abc",
	}

	t.Run("empty input", func(t *testing.T) {
		out, err := formatUserInputMultiContent(nil, vs, FString)
		assert.NoError(t, err)
		assert.Equal(t, []MessageInputPart{}, out)
	})

	t.Run("render text and both URL/Base64 for each type", func(t *testing.T) {
		in := []MessageInputPart{
			{Type: ChatMessagePartTypeText, Text: "hello {x}"},
			{Type: ChatMessagePartTypeImageURL, Image: &MessageInputImage{MessagePartCommon: MessagePartCommon{URL: generic.PtrOf("{img}"), Base64Data: generic.PtrOf("{b64}")}}},
			{Type: ChatMessagePartTypeAudioURL, Audio: &MessageInputAudio{MessagePartCommon: MessagePartCommon{URL: generic.PtrOf("http://a/{aid}.wav"), Base64Data: generic.PtrOf("{b64}")}}},
			{Type: ChatMessagePartTypeVideoURL, Video: &MessageInputVideo{MessagePartCommon: MessagePartCommon{URL: generic.PtrOf("http://v/{vid}.mp4"), Base64Data: generic.PtrOf("{b64}")}}},
			{Type: ChatMessagePartTypeFileURL, File: &MessageInputFile{MessagePartCommon: MessagePartCommon{URL: generic.PtrOf("/f/{file}.txt"), Base64Data: generic.PtrOf("{b64}")}}},
		}

		out, err := formatUserInputMultiContent(in, vs, FString)
		assert.NoError(t, err)
		if assert.Len(t, out, len(in)) {
			assert.Equal(t, "hello world", out[0].Text)
			assert.Equal(t, "https://example.com/i.png", *out[1].Image.URL)
			assert.Equal(t, "YmFzZTY0", *out[1].Image.Base64Data)
			assert.Equal(t, "http://a/99.wav", *out[2].Audio.URL)
			assert.Equal(t, "YmFzZTY0", *out[2].Audio.Base64Data)
			assert.Equal(t, "http://v/77.mp4", *out[3].Video.URL)
			assert.Equal(t, "YmFzZTY0", *out[3].Video.Base64Data)
			assert.Equal(t, "/f/abc.txt", *out[4].File.URL)
			assert.Equal(t, "YmFzZTY0", *out[4].File.Base64Data)
		}
	})

	t.Run("empty string pointer should not be formatted", func(t *testing.T) {
		empty := ""
		in := []MessageInputPart{
			{Type: ChatMessagePartTypeImageURL, Image: &MessageInputImage{MessagePartCommon: MessagePartCommon{URL: &empty, Base64Data: &empty}}},
		}

		out, err := formatUserInputMultiContent(in, vs, FString)
		assert.NoError(t, err)
		if assert.Len(t, out, 1) {
			assert.NotNil(t, out[0].Image.URL)
			assert.NotNil(t, out[0].Image.Base64Data)
			assert.Equal(t, "", *out[0].Image.URL)
			assert.Equal(t, "", *out[0].Image.Base64Data)
		}
	})
}

func TestConcatToolResults(t *testing.T) {
	t.Run("empty_chunks", func(t *testing.T) {
		result, err := ConcatToolResults([]*ToolResult{})
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.Empty(t, result.Parts)
	})

	t.Run("nil_chunks", func(t *testing.T) {
		result, err := ConcatToolResults([]*ToolResult{nil, nil})
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.Empty(t, result.Parts)
	})

	t.Run("single_text_part", func(t *testing.T) {
		chunks := []*ToolResult{
			{
				Parts: []ToolOutputPart{
					{Type: ToolPartTypeText, Text: "Hello World"},
				},
			},
		}

		result, err := ConcatToolResults(chunks)
		assert.NoError(t, err)
		assert.Len(t, result.Parts, 1)
		assert.Equal(t, ToolPartTypeText, result.Parts[0].Type)
		assert.Equal(t, "Hello World", result.Parts[0].Text)
	})

	t.Run("multiple_text_parts_merge", func(t *testing.T) {
		chunks := []*ToolResult{
			{
				Parts: []ToolOutputPart{
					{Type: ToolPartTypeText, Text: "Hello "},
				},
			},
			{
				Parts: []ToolOutputPart{
					{Type: ToolPartTypeText, Text: "World"},
				},
			},
			{
				Parts: []ToolOutputPart{
					{Type: ToolPartTypeText, Text: "!"},
				},
			},
		}

		result, err := ConcatToolResults(chunks)
		assert.NoError(t, err)
		assert.Len(t, result.Parts, 3)
	})

	t.Run("multiple_text_parts_merge_with_extra", func(t *testing.T) {
		chunks
:= []*ToolResult{ { Parts: []ToolOutputPart{ {Type: ToolPartTypeText, Text: "Hello ", Extra: map[string]any{"k1": "v1"}}, {Type: ToolPartTypeText, Text: "World", Extra: map[string]any{"k2": "v2"}}, }, }, } result, err := ConcatToolResults(chunks) assert.NoError(t, err) assert.Len(t, result.Parts, 1) assert.Equal(t, "Hello World", result.Parts[0].Text) assert.Equal(t, map[string]any{"k1": "v1", "k2": "v2"}, result.Parts[0].Extra) }) t.Run("multiple_text_parts_merge_with_single_extra", func(t *testing.T) { chunks := []*ToolResult{ { Parts: []ToolOutputPart{ {Type: ToolPartTypeText, Text: "Hello ", Extra: map[string]any{"k1": "v1"}}, {Type: ToolPartTypeText, Text: "World"}, }, }, } result, err := ConcatToolResults(chunks) assert.NoError(t, err) assert.Len(t, result.Parts, 1) assert.Equal(t, "Hello World", result.Parts[0].Text) assert.Equal(t, map[string]any{"k1": "v1"}, result.Parts[0].Extra) }) t.Run("cross_chunk_audio_conflict_error", func(t *testing.T) { base64Data1 := "YXVkaW8x" base64Data2 := "YXVkaW8y" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ { Type: ToolPartTypeAudio, Audio: &ToolOutputAudio{ MessagePartCommon: MessagePartCommon{ Base64Data: &base64Data1, MIMEType: "audio/wav", }, }, }, }, }, { Parts: []ToolOutputPart{ { Type: ToolPartTypeAudio, Audio: &ToolOutputAudio{ MessagePartCommon: MessagePartCommon{ Base64Data: &base64Data2, MIMEType: "audio/wav", }, }, }, }, }, } _, err := ConcatToolResults(chunks) assert.Error(t, err) assert.Contains(t, err.Error(), "conflicting") assert.Contains(t, err.Error(), "audio") }) t.Run("mixed_types_no_merge", func(t *testing.T) { imageURL := "https://example.com/image.png" videoURL := "https://example.com/video.mp4" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ {Type: ToolPartTypeText, Text: "Text part"}, { Type: ToolPartTypeImage, Image: &ToolOutputImage{ MessagePartCommon: MessagePartCommon{ URL: &imageURL, }, }, }, }, }, { Parts: []ToolOutputPart{ { Type: ToolPartTypeVideo, Video: &ToolOutputVideo{ 
MessagePartCommon: MessagePartCommon{ URL: &videoURL, }, }, }, }, }, } result, err := ConcatToolResults(chunks) assert.NoError(t, err) assert.Len(t, result.Parts, 3) assert.Equal(t, ToolPartTypeText, result.Parts[0].Type) assert.Equal(t, ToolPartTypeImage, result.Parts[1].Type) assert.Equal(t, ToolPartTypeVideo, result.Parts[2].Type) }) t.Run("mixed_text_and_single_audio", func(t *testing.T) { base64Data1 := "YXVkaW8x" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ {Type: ToolPartTypeText, Text: "Part 1 "}, {Type: ToolPartTypeText, Text: "Part 2"}, }, }, { Parts: []ToolOutputPart{ { Type: ToolPartTypeAudio, Audio: &ToolOutputAudio{ MessagePartCommon: MessagePartCommon{ Base64Data: &base64Data1, MIMEType: "audio/wav", }, }, }, }, }, { Parts: []ToolOutputPart{ {Type: ToolPartTypeText, Text: " Part 3"}, }, }, } result, err := ConcatToolResults(chunks) assert.NoError(t, err) assert.Len(t, result.Parts, 3) assert.Equal(t, ToolPartTypeText, result.Parts[0].Type) assert.Equal(t, "Part 1 Part 2", result.Parts[0].Text) assert.Equal(t, ToolPartTypeAudio, result.Parts[1].Type) assert.NotNil(t, result.Parts[1].Audio) assert.NotNil(t, result.Parts[1].Audio.Base64Data) assert.Equal(t, "YXVkaW8x", *result.Parts[1].Audio.Base64Data) assert.Equal(t, ToolPartTypeText, result.Parts[2].Type) assert.Equal(t, " Part 3", result.Parts[2].Text) }) t.Run("cross_chunk_audio_url_and_base64_conflict_error", func(t *testing.T) { audioURL := "https://example.com/audio.wav" base64Data := "YXVkaW8x" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ { Type: ToolPartTypeAudio, Audio: &ToolOutputAudio{ MessagePartCommon: MessagePartCommon{ URL: &audioURL, MIMEType: "audio/wav", }, }, }, }, }, { Parts: []ToolOutputPart{ { Type: ToolPartTypeAudio, Audio: &ToolOutputAudio{ MessagePartCommon: MessagePartCommon{ Base64Data: &base64Data, MIMEType: "audio/wav", }, }, }, }, }, } _, err := ConcatToolResults(chunks) assert.Error(t, err) assert.Contains(t, err.Error(), "conflicting") assert.Contains(t, 
err.Error(), "audio") }) t.Run("single_audio_with_extra_fields", func(t *testing.T) { base64Data1 := "YXVkaW8x" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ { Type: ToolPartTypeAudio, Audio: &ToolOutputAudio{ MessagePartCommon: MessagePartCommon{ Base64Data: &base64Data1, MIMEType: "audio/wav", Extra: map[string]any{ "key1": "value1", }, }, }, }, }, }, } result, err := ConcatToolResults(chunks) assert.NoError(t, err) assert.Len(t, result.Parts, 1) assert.Equal(t, ToolPartTypeAudio, result.Parts[0].Type) assert.NotNil(t, result.Parts[0].Audio) assert.NotNil(t, result.Parts[0].Audio.Base64Data) assert.Equal(t, "YXVkaW8x", *result.Parts[0].Audio.Base64Data) assert.NotNil(t, result.Parts[0].Audio.Extra) assert.Equal(t, "value1", result.Parts[0].Audio.Extra["key1"]) }) t.Run("cross_chunk_image_conflict_error", func(t *testing.T) { imageURL1 := "https://example.com/image1.png" imageURL2 := "https://example.com/image2.png" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ { Type: ToolPartTypeImage, Image: &ToolOutputImage{ MessagePartCommon: MessagePartCommon{ URL: &imageURL1, }, }, }, }, }, { Parts: []ToolOutputPart{ { Type: ToolPartTypeImage, Image: &ToolOutputImage{ MessagePartCommon: MessagePartCommon{ URL: &imageURL2, }, }, }, }, }, } _, err := ConcatToolResults(chunks) assert.Error(t, err) assert.Contains(t, err.Error(), "conflicting") assert.Contains(t, err.Error(), "image") }) t.Run("cross_chunk_video_conflict_error", func(t *testing.T) { videoURL1 := "https://example.com/video1.mp4" videoURL2 := "https://example.com/video2.mp4" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ { Type: ToolPartTypeVideo, Video: &ToolOutputVideo{ MessagePartCommon: MessagePartCommon{ URL: &videoURL1, }, }, }, }, }, { Parts: []ToolOutputPart{ { Type: ToolPartTypeVideo, Video: &ToolOutputVideo{ MessagePartCommon: MessagePartCommon{ URL: &videoURL2, }, }, }, }, }, } _, err := ConcatToolResults(chunks) assert.Error(t, err) assert.Contains(t, err.Error(), "conflicting") 
assert.Contains(t, err.Error(), "video") }) t.Run("cross_chunk_file_conflict_error", func(t *testing.T) { fileURL1 := "https://example.com/file1.pdf" fileURL2 := "https://example.com/file2.pdf" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ { Type: ToolPartTypeFile, File: &ToolOutputFile{ MessagePartCommon: MessagePartCommon{ URL: &fileURL1, }, }, }, }, }, { Parts: []ToolOutputPart{ { Type: ToolPartTypeFile, File: &ToolOutputFile{ MessagePartCommon: MessagePartCommon{ URL: &fileURL2, }, }, }, }, }, } _, err := ConcatToolResults(chunks) assert.Error(t, err) assert.Contains(t, err.Error(), "conflicting") assert.Contains(t, err.Error(), "file") }) t.Run("cross_chunk_text_not_merged", func(t *testing.T) { chunks := []*ToolResult{ { Parts: []ToolOutputPart{ {Type: ToolPartTypeText, Text: "Hello "}, }, }, { Parts: []ToolOutputPart{ {Type: ToolPartTypeText, Text: "World"}, }, }, } result, err := ConcatToolResults(chunks) assert.NoError(t, err) assert.Len(t, result.Parts, 2) assert.Equal(t, ToolPartTypeText, result.Parts[0].Type) assert.Equal(t, "Hello ", result.Parts[0].Text) assert.Equal(t, ToolPartTypeText, result.Parts[1].Type) assert.Equal(t, "World", result.Parts[1].Text) }) t.Run("same_chunk_text_merged", func(t *testing.T) { chunks := []*ToolResult{ { Parts: []ToolOutputPart{ {Type: ToolPartTypeText, Text: "Hello "}, {Type: ToolPartTypeText, Text: "World"}, }, }, } result, err := ConcatToolResults(chunks) assert.NoError(t, err) assert.Len(t, result.Parts, 1) assert.Equal(t, ToolPartTypeText, result.Parts[0].Type) assert.Equal(t, "Hello World", result.Parts[0].Text) }) t.Run("different_non_text_types_across_chunks_allowed", func(t *testing.T) { imageURL := "https://example.com/image.png" videoURL := "https://example.com/video.mp4" base64Audio := "YXVkaW8x" chunks := []*ToolResult{ { Parts: []ToolOutputPart{ { Type: ToolPartTypeImage, Image: &ToolOutputImage{ MessagePartCommon: MessagePartCommon{ URL: &imageURL, }, }, }, }, }, { Parts: []ToolOutputPart{ { Type: 
ToolPartTypeVideo, Video: &ToolOutputVideo{ MessagePartCommon: MessagePartCommon{ URL: &videoURL, }, }, }, }, }, { Parts: []ToolOutputPart{ { Type: ToolPartTypeAudio, Audio: &ToolOutputAudio{ MessagePartCommon: MessagePartCommon{ Base64Data: &base64Audio, MIMEType: "audio/wav", }, }, }, }, }, } result, err := ConcatToolResults(chunks) assert.NoError(t, err) assert.Len(t, result.Parts, 3) assert.Equal(t, ToolPartTypeImage, result.Parts[0].Type) assert.Equal(t, ToolPartTypeVideo, result.Parts[1].Type) assert.Equal(t, ToolPartTypeAudio, result.Parts[2].Type) }) } func TestMessageString(t *testing.T) { t.Run("basic message", func(t *testing.T) { msg := &Message{ Role: User, Content: "Hello, world!", } result := msg.String() assert.Contains(t, result, "user: Hello, world!") }) t.Run("message with UserInputMultiContent", func(t *testing.T) { imageURL := "https://example.com/image.png" msg := &Message{ Role: User, Content: "", UserInputMultiContent: []MessageInputPart{ {Type: ChatMessagePartTypeText, Text: "Describe this image:"}, {Type: ChatMessagePartTypeImageURL, Image: &MessageInputImage{ MessagePartCommon: MessagePartCommon{URL: &imageURL}, }}, }, } result := msg.String() assert.Contains(t, result, "user_input_multi_content:") assert.Contains(t, result, "[0] text: Describe this image:") assert.Contains(t, result, "[1] image: url=https://example.com/image.png") }) t.Run("message with AssistantGenMultiContent", func(t *testing.T) { base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" msg := &Message{ Role: Assistant, Content: "", AssistantGenMultiContent: []MessageOutputPart{ {Type: ChatMessagePartTypeText, Text: "Here is the generated image:"}, {Type: ChatMessagePartTypeImageURL, Image: &MessageOutputImage{ MessagePartCommon: MessagePartCommon{ Base64Data: &base64Data, MIMEType: "image/png", }, }}, }, } result := msg.String() assert.Contains(t, result, "assistant_gen_multi_content:") assert.Contains(t, 
result, "[0] text: Here is the generated image:") assert.Contains(t, result, "[1] image: base64[") assert.Contains(t, result, "mime=image/png") }) t.Run("message with MultiContent (deprecated)", func(t *testing.T) { msg := &Message{ Role: User, Content: "", MultiContent: []ChatMessagePart{ {Type: ChatMessagePartTypeText, Text: "What is this?"}, {Type: ChatMessagePartTypeImageURL, ImageURL: &ChatMessageImageURL{URL: "https://example.com/photo.jpg"}}, }, } result := msg.String() assert.Contains(t, result, "multi_content:") assert.Contains(t, result, "[0] text: What is this?") assert.Contains(t, result, "[1] image_url: https://example.com/photo.jpg") }) t.Run("message with ToolCalls", func(t *testing.T) { idx := 0 msg := &Message{ Role: Assistant, Content: "", ToolCalls: []ToolCall{ { Index: &idx, ID: "call_123", Type: "function", Function: FunctionCall{ Name: "get_weather", Arguments: `{"location": "Beijing"}`, }, }, }, } result := msg.String() assert.Contains(t, result, "tool_calls:") assert.Contains(t, result, "index[0]:") assert.Contains(t, result, "get_weather") }) t.Run("tool message", func(t *testing.T) { msg := &Message{ Role: Tool, Content: `{"temperature": 25}`, ToolCallID: "call_123", ToolName: "get_weather", } result := msg.String() assert.Contains(t, result, "tool: {\"temperature\": 25}") assert.Contains(t, result, "tool_call_id: call_123") assert.Contains(t, result, "tool_call_name: get_weather") }) t.Run("message with reasoning content", func(t *testing.T) { msg := &Message{ Role: Assistant, Content: "The answer is 42.", ReasoningContent: "Let me think about this step by step...", } result := msg.String() assert.Contains(t, result, "reasoning content:") assert.Contains(t, result, "Let me think about this step by step...") }) t.Run("message with response meta", func(t *testing.T) { msg := &Message{ Role: Assistant, Content: "Hello!", ResponseMeta: &ResponseMeta{ FinishReason: "stop", Usage: &TokenUsage{ PromptTokens: 10, CompletionTokens: 5, TotalTokens: 
15, }, }, } result := msg.String() assert.Contains(t, result, "finish_reason: stop") assert.Contains(t, result, "usage:") }) t.Run("message with audio input", func(t *testing.T) { audioURL := "https://example.com/audio.wav" msg := &Message{ Role: User, UserInputMultiContent: []MessageInputPart{ {Type: ChatMessagePartTypeAudioURL, Audio: &MessageInputAudio{ MessagePartCommon: MessagePartCommon{URL: &audioURL}, }}, }, } result := msg.String() assert.Contains(t, result, "[0] audio: url=https://example.com/audio.wav") }) t.Run("message with video input", func(t *testing.T) { videoURL := "https://example.com/video.mp4" msg := &Message{ Role: User, UserInputMultiContent: []MessageInputPart{ {Type: ChatMessagePartTypeVideoURL, Video: &MessageInputVideo{ MessagePartCommon: MessagePartCommon{URL: &videoURL}, }}, }, } result := msg.String() assert.Contains(t, result, "[0] video: url=https://example.com/video.mp4") }) t.Run("message with file input", func(t *testing.T) { fileURL := "https://example.com/document.pdf" msg := &Message{ Role: User, UserInputMultiContent: []MessageInputPart{ {Type: ChatMessagePartTypeFileURL, File: &MessageInputFile{ MessagePartCommon: MessagePartCommon{URL: &fileURL}, }}, }, } result := msg.String() assert.Contains(t, result, "[0] file: url=https://example.com/document.pdf") }) t.Run("nil media parts", func(t *testing.T) { msg := &Message{ Role: User, UserInputMultiContent: []MessageInputPart{ {Type: ChatMessagePartTypeImageURL, Image: nil}, }, } result := msg.String() assert.Contains(t, result, "[0] image: ") }) t.Run("combined multi-content types", func(t *testing.T) { imageURL := "https://example.com/image.png" base64Audio := "YXVkaW9kYXRh" msg := &Message{ Role: User, Content: "Main content", UserInputMultiContent: []MessageInputPart{ {Type: ChatMessagePartTypeText, Text: "User input text"}, {Type: ChatMessagePartTypeImageURL, Image: &MessageInputImage{ MessagePartCommon: MessagePartCommon{URL: &imageURL}, }}, }, AssistantGenMultiContent: 
[]MessageOutputPart{ {Type: ChatMessagePartTypeText, Text: "Assistant output text"}, {Type: ChatMessagePartTypeAudioURL, Audio: &MessageOutputAudio{ MessagePartCommon: MessagePartCommon{ Base64Data: &base64Audio, MIMEType: "audio/wav", }, }}, }, } result := msg.String() assert.Contains(t, result, "user: Main content") assert.Contains(t, result, "user_input_multi_content:") assert.Contains(t, result, "assistant_gen_multi_content:") }) } func TestConvToolOutputPartToMessageInputPart(t *testing.T) { t.Run("text part", func(t *testing.T) { toolPart := ToolOutputPart{ Type: ToolPartTypeText, Text: "test text", Extra: map[string]any{"key": "value"}, } result, err := convToolOutputPartToMessageInputPart(toolPart) assert.NoError(t, err) assert.Equal(t, ChatMessagePartTypeText, result.Type) assert.Equal(t, "test text", result.Text) assert.Equal(t, map[string]any{"key": "value"}, result.Extra) }) t.Run("image part", func(t *testing.T) { url := "https://example.com/image.png" toolPart := ToolOutputPart{ Type: ToolPartTypeImage, Image: &ToolOutputImage{ MessagePartCommon: MessagePartCommon{ URL: &url, MIMEType: "image/png", }, }, Extra: map[string]any{"img_key": "img_value"}, } result, err := convToolOutputPartToMessageInputPart(toolPart) assert.NoError(t, err) assert.Equal(t, ChatMessagePartTypeImageURL, result.Type) assert.NotNil(t, result.Image) assert.Equal(t, url, *result.Image.URL) assert.Equal(t, "image/png", result.Image.MIMEType) assert.Equal(t, map[string]any{"img_key": "img_value"}, result.Extra) }) t.Run("image part nil content", func(t *testing.T) { toolPart := ToolOutputPart{ Type: ToolPartTypeImage, Image: nil, } result, err := convToolOutputPartToMessageInputPart(toolPart) assert.Error(t, err) assert.ErrorContains(t, err, "image content is nil") assert.Equal(t, MessageInputPart{}, result) }) t.Run("audio part", func(t *testing.T) { base64Data := "dGVzdF9hdWRpbw==" toolPart := ToolOutputPart{ Type: ToolPartTypeAudio, Audio: &ToolOutputAudio{ MessagePartCommon: 
MessagePartCommon{ Base64Data: &base64Data, MIMEType: "audio/wav", }, }, } result, err := convToolOutputPartToMessageInputPart(toolPart) assert.NoError(t, err) assert.Equal(t, ChatMessagePartTypeAudioURL, result.Type) assert.NotNil(t, result.Audio) assert.Equal(t, base64Data, *result.Audio.Base64Data) assert.Equal(t, "audio/wav", result.Audio.MIMEType) }) t.Run("audio part nil content", func(t *testing.T) { toolPart := ToolOutputPart{ Type: ToolPartTypeAudio, Audio: nil, } _, err := convToolOutputPartToMessageInputPart(toolPart) assert.Error(t, err) assert.ErrorContains(t, err, "audio content is nil") }) t.Run("video part", func(t *testing.T) { url := "https://example.com/video.mp4" toolPart := ToolOutputPart{ Type: ToolPartTypeVideo, Video: &ToolOutputVideo{ MessagePartCommon: MessagePartCommon{ URL: &url, MIMEType: "video/mp4", }, }, } result, err := convToolOutputPartToMessageInputPart(toolPart) assert.NoError(t, err) assert.Equal(t, ChatMessagePartTypeVideoURL, result.Type) assert.NotNil(t, result.Video) assert.Equal(t, url, *result.Video.URL) assert.Equal(t, "video/mp4", result.Video.MIMEType) }) t.Run("video part nil content", func(t *testing.T) { toolPart := ToolOutputPart{ Type: ToolPartTypeVideo, Video: nil, } _, err := convToolOutputPartToMessageInputPart(toolPart) assert.Error(t, err) assert.ErrorContains(t, err, "video content is nil") }) t.Run("file part", func(t *testing.T) { url := "https://example.com/file.pdf" toolPart := ToolOutputPart{ Type: ToolPartTypeFile, File: &ToolOutputFile{ MessagePartCommon: MessagePartCommon{ URL: &url, MIMEType: "application/pdf", }, }, Extra: map[string]any{"file_key": "file_value"}, } result, err := convToolOutputPartToMessageInputPart(toolPart) assert.NoError(t, err) assert.Equal(t, ChatMessagePartTypeFileURL, result.Type) assert.NotNil(t, result.File) assert.Equal(t, url, *result.File.URL) assert.Equal(t, "application/pdf", result.File.MIMEType) assert.Equal(t, map[string]any{"file_key": "file_value"}, 
result.Extra) }) t.Run("file part nil content", func(t *testing.T) { toolPart := ToolOutputPart{ Type: ToolPartTypeFile, File: nil, } _, err := convToolOutputPartToMessageInputPart(toolPart) assert.Error(t, err) assert.ErrorContains(t, err, "file content is nil") }) t.Run("unknown type", func(t *testing.T) { toolPart := ToolOutputPart{ Type: "unknown_type", } _, err := convToolOutputPartToMessageInputPart(toolPart) assert.Error(t, err) assert.ErrorContains(t, err, "unknown tool part type") }) } ================================================ FILE: schema/select.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package schema const maxSelectNum = 5 func receiveN[T any](chosenList []int, ss []*stream[T]) (int, *streamItem[T], bool) { return []func(chosenList []int, ss []*stream[T]) (index int, item *streamItem[T], ok bool){ nil, func(chosenList []int, ss []*stream[T]) (int, *streamItem[T], bool) { item, ok := <-ss[chosenList[0]].items return chosenList[0], &item, ok }, func(chosenList []int, ss []*stream[T]) (int, *streamItem[T], bool) { select { case item, ok := <-ss[chosenList[0]].items: return chosenList[0], &item, ok case item, ok := <-ss[chosenList[1]].items: return chosenList[1], &item, ok } }, func(chosenList []int, ss []*stream[T]) (int, *streamItem[T], bool) { select { case item, ok := <-ss[chosenList[0]].items: return chosenList[0], &item, ok case item, ok := <-ss[chosenList[1]].items: return chosenList[1], &item, ok case item, ok := <-ss[chosenList[2]].items: return chosenList[2], &item, ok } }, func(chosenList []int, ss []*stream[T]) (int, *streamItem[T], bool) { select { case item, ok := <-ss[chosenList[0]].items: return chosenList[0], &item, ok case item, ok := <-ss[chosenList[1]].items: return chosenList[1], &item, ok case item, ok := <-ss[chosenList[2]].items: return chosenList[2], &item, ok case item, ok := <-ss[chosenList[3]].items: return chosenList[3], &item, ok } }, func(chosenList []int, ss []*stream[T]) (int, *streamItem[T], bool) { select { case item, ok := <-ss[chosenList[0]].items: return chosenList[0], &item, ok case item, ok := <-ss[chosenList[1]].items: return chosenList[1], &item, ok case item, ok := <-ss[chosenList[2]].items: return chosenList[2], &item, ok case item, ok := <-ss[chosenList[3]].items: return chosenList[3], &item, ok case item, ok := <-ss[chosenList[4]].items: return chosenList[4], &item, ok } }, }[len(chosenList)](chosenList, ss) } ================================================ FILE: schema/serialization.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache 
License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package schema import ( "encoding/gob" "reflect" "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/internal/serialization" ) func init() { RegisterName[Message]("_eino_message") RegisterName[[]*Message]("_eino_message_slice") RegisterName[Document]("_eino_document") RegisterName[RoleType]("_eino_role_type") RegisterName[ToolCall]("_eino_tool_call") RegisterName[FunctionCall]("_eino_function_call") RegisterName[ResponseMeta]("_eino_response_meta") RegisterName[TokenUsage]("_eino_token_usage") RegisterName[LogProbs]("_eino_log_probs") RegisterName[ChatMessagePart]("_eino_chat_message_part") RegisterName[ChatMessagePartType]("_eino_chat_message_type") RegisterName[ChatMessageImageURL]("_eino_chat_message_image_url") RegisterName[ChatMessageAudioURL]("_eino_chat_message_audio_url") RegisterName[ChatMessageVideoURL]("_eino_chat_message_video_url") RegisterName[ChatMessageFileURL]("_eino_chat_message_file_url") RegisterName[MessageInputPart]("_eino_message_input_part") RegisterName[MessageInputImage]("_eino_message_input_image") RegisterName[MessageInputAudio]("_eino_message_input_audio") RegisterName[MessageInputVideo]("_eino_message_input_video") RegisterName[MessageInputFile]("_eino_message_input_file") RegisterName[MessageOutputPart]("_eino_message_output_part") RegisterName[MessageOutputImage]("_eino_message_output_image") RegisterName[MessageOutputAudio]("_eino_message_output_audio") 
RegisterName[MessageOutputVideo]("_eino_message_output_video") RegisterName[MessagePartCommon]("_eino_message_part_common") RegisterName[ImageURLDetail]("_eino_image_url_detail") RegisterName[PromptTokenDetails]("_eino_prompt_token_details") } // RegisterName registers a type with a specific name for serialization. This is // required for any type you intend to persist in a graph or ADK checkpoint. // Use this function to maintain backward compatibility by mapping a type to a // previously used name. For new types, `Register` is preferred. // // It is recommended to call this in an `init()` function in the file where the // type is declared. // // What to Register: // - Top-level types used as state (e.g., structs). // - Concrete types that are assigned to interface fields. // // What NOT to Register: // - Struct fields with concrete types (e.g., `string`, `int`, other structs). // These are inferred via reflection. // // Serialization Rules: // // The serialization behavior is based on Go's standard `encoding/gob` package. // See https://pkg.go.dev/encoding/gob for detailed rules. // - Only exported struct fields are serialized. // - Functions and channels are not supported and will be ignored. // // This function panics if registration fails. func RegisterName[T any](name string) { gob.RegisterName(name, generic.NewInstance[T]()) err := serialization.GenericRegister[T](name) if err != nil { panic(err) } } func getTypeName(rt reflect.Type) string { name := rt.String() // But for named types (or pointers to them), qualify with import path. // Dereference one pointer looking for a named type. star := "" if rt.Name() == "" { if pt := rt; pt.Kind() == reflect.Pointer { star = "*" rt = pt.Elem() } } if rt.Name() != "" { if rt.PkgPath() == "" { name = star + rt.Name() } else { name = star + rt.PkgPath() + "." + rt.Name() } } return name } // Register registers a type for serialization. This is required for any type // you intend to persist in a graph or ADK checkpoint. 
It automatically determines // the type name and is the recommended method for registering new types. // // It is recommended to call this in an `init()` function in the file where the // type is declared. // // What to Register: // - Top-level types used as state (e.g., structs). // - Concrete types that are assigned to interface fields. // // What NOT to Register: // - Struct fields with concrete types (e.g., `string`, `int`, other structs). // These are inferred via reflection. // // Serialization Rules: // // The serialization behavior is based on Go's standard `encoding/gob` package. // See https://pkg.go.dev/encoding/gob for detailed rules. // - Only exported struct fields are serialized. // - Functions and channels are not supported and will be ignored. // // This function panics if registration fails. func Register[T any]() { value := generic.NewInstance[T]() gob.Register(value) name := getTypeName(reflect.TypeOf(value)) err := serialization.GenericRegister[T](name) if err != nil { panic(err) } } ================================================ FILE: schema/serialization_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package schema import ( "bytes" "encoding/gob" "fmt" "reflect" "testing" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/internal/serialization" ) type testStruct struct{} func TestGetTypeName(t *testing.T) { type localNamedType struct{} testCases := []struct { name string input reflect.Type expected string }{ { name: "named type from current package", input: reflect.TypeOf(testStruct{}), expected: "github.com/cloudwego/eino/schema.testStruct", }, { name: "pointer to named type from current package", input: reflect.TypeOf(&testStruct{}), expected: "*github.com/cloudwego/eino/schema.testStruct", }, { name: "unnamed map type", input: reflect.TypeOf(map[string]int{}), expected: "map[string]int", }, { name: "pointer to unnamed map type", input: reflect.TypeOf(new(map[string]int)), expected: "*map[string]int", }, { name: "built-in type", input: reflect.TypeOf(0), expected: "int", }, { name: "pointer to built-in type", input: reflect.TypeOf(new(int)), expected: "*int", }, { name: "named type from standard library", input: reflect.TypeOf(bytes.Buffer{}), expected: "bytes.Buffer", }, { name: "pointer to named type from standard library", input: reflect.TypeOf(&bytes.Buffer{}), expected: "*bytes.Buffer", }, { name: "local named type", input: reflect.TypeOf(localNamedType{}), expected: "github.com/cloudwego/eino/schema.localNamedType", }, { name: "pointer to local named type", input: reflect.TypeOf(&localNamedType{}), expected: "*github.com/cloudwego/eino/schema.localNamedType", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := getTypeName(tc.input) if actual != tc.expected { t.Errorf("getTypeName() got %q, want %q", actual, tc.expected) } }) } } func TestRegister(t *testing.T) { type testStruct1 struct { A any B any C any D any E any F any } type testStruct2 struct{} Register[*testStruct1]() Register[*testStruct2]() Register[[]Message]() Register[[]*testStruct2]() Register[[]testStruct2]() t1 := 
testStruct1{A: []*Message{{}}, B: []Message{{}}, C: []*testStruct2{{}}, D: []testStruct2{{}}, E: &testStruct1{}, F: []int{1}} in := &serialization.InternalSerializer{} mar, err := in.Marshal(t1) if err != nil { panic(err) } var t2 testStruct1 err = in.Unmarshal(mar, &t2) if err != nil { panic(err) } assert.Equal(t, t1, t2) buf := new(bytes.Buffer) err = gob.NewEncoder(buf).Encode(t1) if err != nil { panic(err) } err = gob.NewDecoder(buf).Decode(&t2) if err != nil { panic(err) } assert.Equal(t, t1, t2) f := func() (err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("panic: %v", r) } }() Register[[]int]() Register[map[string]any]() Register[[]*testStruct1]() Register[[]testStruct1]() return nil } err = f() assert.NoError(t, err) } // TestRegisterStructWithUUIDField reproduces issue #607 // uuid.UUID is a [16]byte array. Prior to the fix, calling schema.RegisterName on // a struct with a uuid.UUID field would panic during deserialization. func TestRegisterStructWithUUIDField(t *testing.T) { type Item struct { ID uuid.UUID } RegisterName[Item]("test_item") original := Item{ ID: uuid.MustParse("6ba7b810-9dad-11d1-80b4-00c04fd430c8"), } s := &serialization.InternalSerializer{} data, err := s.Marshal(original) assert.NoError(t, err) var result Item err = s.Unmarshal(data, &result) assert.NoError(t, err) assert.Equal(t, original.ID, result.ID) } ================================================ FILE: schema/stream.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package schema import ( "errors" "fmt" "io" "reflect" "runtime" "runtime/debug" "sync" "sync/atomic" "github.com/cloudwego/eino/internal/safe" ) // ErrNoValue is a sentinel returned from the convert function passed to // [StreamReaderWithConvert] to skip a stream element — the element is dropped // and the next one is read without surfacing an error to the caller. // // Use it to filter out empty or irrelevant chunks: // // outStream = schema.StreamReaderWithConvert(s, // func(src string) (string, error) { // if len(src) == 0 { // return "", schema.ErrNoValue // skip empty chunks // } // return src, nil // }) // // DO NOT use ErrNoValue in any other context. var ErrNoValue = errors.New("no value") // ErrRecvAfterClosed indicates that StreamReader.Recv was unexpectedly called after StreamReader.Close. // This error should not occur during normal use of StreamReader.Recv. If it does, please check your application code. var ErrRecvAfterClosed = errors.New("recv after stream closed") // SourceEOF represents an EOF error from a specific source stream. // It is only returned by the method Recv() of StreamReader created // with MergeNamedStreamReaders when one of its source streams reaches EOF. type SourceEOF struct { sourceName string } func (e *SourceEOF) Error() string { return fmt.Sprintf("EOF from source stream: %s", e.sourceName) } // GetSourceName extracts the source stream name from a SourceEOF error. // It returns the source name and a boolean indicating whether the error was a SourceEOF. // If the error is not a SourceEOF, it returns an empty string and false. func GetSourceName(err error) (string, bool) { var sErr *SourceEOF if errors.As(err, &sErr) { return sErr.sourceName, true } return "", false } // Pipe creates a new stream with the given capacity that represented with StreamWriter and StreamReader. 
// The capacity is the maximum number of items that can be buffered in the stream. // e.g. // // sr, sw := schema.Pipe[string](3) // go func() { // send data // defer sw.Close() // for i := 0; i < 10; i++ { // sw.Send(i, nil) // } // } // // defer sr.Close() // for chunk, err := sr.Recv() { // if errors.Is(err, io.EOF) { // break // } // fmt.Println(chunk) // } func Pipe[T any](cap int) (*StreamReader[T], *StreamWriter[T]) { stm := newStream[T](cap) return stm.asReader(), &StreamWriter[T]{stm: stm} } // StreamWriter the sender of a stream. // created by Pipe function. // eg. // // sr, sw := schema.Pipe[string](3) // go func() { // send data // defer sw.Close() // for i := 0; i < 10; i++ { // sw.Send(i, nil) // } // } type StreamWriter[T any] struct { stm *stream[T] } // Send sends a value to the stream. // e.g. // // closed := sw.Send(i, nil) // if closed { // // the stream is closed // } func (sw *StreamWriter[T]) Send(chunk T, err error) (closed bool) { return sw.stm.send(chunk, err) } // Close notify the receiver that the stream sender has finished. // The stream receiver will get an error of io.EOF from StreamReader.Recv(). // Notice: always remember to call Close() after sending all data. // eg. // // defer sw.Close() // for i := 0; i < 10; i++ { // sw.Send(i, nil) // } func (sw *StreamWriter[T]) Close() { sw.stm.closeSend() } // StreamReader is the consumer side of an Eino stream. // // A StreamReader is read-once: only one goroutine should call Recv, and the // reader must be closed exactly once (whether the loop finishes normally or // exits early via break or return). // // Typical usage: // // defer sr.Close() // always close, even after io.EOF // for { // chunk, err := sr.Recv() // if errors.Is(err, io.EOF) { // break // } // if err != nil { // return err // } // process(chunk) // } // // To fan-out a single stream to N independent consumers, call [StreamReader.Copy] // before any Recv; the original reader becomes unusable after the call. 
// // StreamReaders are created by [Pipe], [StreamReaderFromArray], // [MergeStreamReaders], [MergeNamedStreamReaders], and [StreamReaderWithConvert]. type StreamReader[T any] struct { typ readerType st *stream[T] ar *arrayReader[T] msr *multiStreamReader[T] srw *streamReaderWithConvert[T] csr *childStreamReader[T] } // Recv receives a value from the stream. // eg. // // for chunk, err := sr.Recv() { // if errors.Is(err, io.EOF) { // break // } // if err != nil { // fmt.Println(chunk) // } func (sr *StreamReader[T]) Recv() (T, error) { switch sr.typ { case readerTypeStream: return sr.st.recv() case readerTypeArray: return sr.ar.recv() case readerTypeMultiStream: return sr.msr.recv() case readerTypeWithConvert: return sr.srw.recv() case readerTypeChild: return sr.csr.recv() default: panic("impossible") } } // Close safely closes the StreamReader. // It should be called only once, as multiple calls may not work as expected. // Notice: always remember to call Close() after using Recv(). // e.g. // // defer sr.Close() // // for chunk, err := sr.Recv() { // if errors.Is(err, io.EOF) { // break // } // fmt.Println(chunk) // } func (sr *StreamReader[T]) Close() { switch sr.typ { case readerTypeStream: sr.st.closeRecv() case readerTypeArray: case readerTypeMultiStream: sr.msr.close() case readerTypeWithConvert: sr.srw.close() case readerTypeChild: sr.csr.close() default: panic("impossible") } } // Copy creates n independent StreamReaders that each receive every element of // the original stream. The original StreamReader becomes unusable after Copy. // // Use Copy when two or more pipeline branches need the same stream — // for example, when a stream must be fed to both a callback handler and the // next node in a graph: // // copies := sr.Copy(2) // sr1, sr2 := copies[0], copies[1] // defer sr1.Close() // defer sr2.Close() // // // sr1 and sr2 independently read the same elements // // n must be at least 1. If n < 2, the original reader is returned unchanged. 
func (sr *StreamReader[T]) Copy(n int) []*StreamReader[T] {
	// Fewer than two copies requested: the caller keeps the original reader.
	if n < 2 {
		return []*StreamReader[T]{sr}
	}

	// Array-backed readers are copied directly: each copy shares the same
	// backing slice and starts at the current read position (arrayReader.copy).
	if sr.typ == readerTypeArray {
		ret := make([]*StreamReader[T], n)
		for i, ar := range sr.ar.copy(n) {
			ret[i] = &StreamReader[T]{typ: readerTypeArray, ar: ar}
		}
		return ret
	}

	// Every other reader kind is fanned out by copyStreamReaders.
	return copyStreamReaders[T](sr, n)
}

// SetAutomaticClose sets the StreamReader to automatically close when it's no longer reachable and ready to be GCed.
// NOT concurrency safe.
//
// What happens depends on the reader kind:
//   - readerTypeStream: a finalizer attached to sr itself calls Close. The
//     stream is switched to automatic-close mode, in which closeRecv performs
//     a compare-and-swap on closedFlag, so an explicit Close followed by the
//     finalizer does not close the channel twice.
//   - readerTypeMultiStream: each source stream that is still open gets its
//     own finalizer (attached to the *stream, not to sr) that calls closeRecv.
//   - readerTypeChild / readerTypeWithConvert: the request is forwarded to the
//     parent / source reader.
//   - readerTypeArray: there is nothing to release.
//
// Streams already in automatic-close mode are skipped, so repeated calls are
// harmless.
//
// NOTE(review): if a Pipe-backed reader was already closed explicitly before
// this call, the fresh closedFlag starts at 0 and the finalizer would close
// the channel a second time (panic); confirm callers never do this.
func (sr *StreamReader[T]) SetAutomaticClose() {
	switch sr.typ {
	case readerTypeStream:
		if !sr.st.automaticClose {
			sr.st.automaticClose = true
			var flag uint32
			sr.st.closedFlag = &flag
			runtime.SetFinalizer(sr, func(s *StreamReader[T]) {
				s.Close()
			})
		}
	case readerTypeMultiStream:
		for _, s := range sr.msr.nonClosedStreams() {
			if !s.automaticClose {
				s.automaticClose = true
				var flag uint32
				s.closedFlag = &flag
				runtime.SetFinalizer(s, func(st *stream[T]) {
					st.closeRecv()
				})
			}
		}
	case readerTypeChild:
		parent := sr.csr.parent.sr
		parent.SetAutomaticClose()
	case readerTypeWithConvert:
		sr.srw.sr.SetAutomaticClose()
	case readerTypeArray:
		// no need to clean up
	default:
	}
}

// recvAny is Recv with the element boxed as any, so that *StreamReader[T]
// satisfies the type-erased iStreamReader interface.
func (sr *StreamReader[T]) recvAny() (any, error) {
	return sr.Recv()
}

// copyAny is Copy returning the copies as iStreamReader values.
func (sr *StreamReader[T]) copyAny(n int) []iStreamReader {
	ret := make([]iStreamReader, n)

	// For n == 1, Copy hands back sr itself.
	srs := sr.Copy(n)

	for i := 0; i < n; i++ {
		ret[i] = srs[i]
	}

	return ret
}

// arrToStream materializes arr into a channel-backed stream. The channel is
// buffered to len(arr), so every send completes without a receiver. The
// stream is then closed for sending, so the reader sees io.EOF after the last
// element.
func arrToStream[T any](arr []T) *stream[T] {
	s := newStream[T](len(arr))
	for i := range arr {
		s.send(arr[i], nil)
	}

	s.closeSend()

	return s
}

// toStream returns the reader's content as a single channel-backed stream:
// Pipe-backed readers return their stream directly, and every other kind
// delegates to its own toStream.
func (sr *StreamReader[T]) toStream() *stream[T] {
	switch sr.typ {
	case readerTypeStream:
		return sr.st
	case readerTypeArray:
		return sr.ar.toStream()
	case readerTypeMultiStream:
		return sr.msr.toStream()
	case readerTypeWithConvert:
		return sr.srw.toStream()
	case readerTypeChild:
		return sr.csr.toStream()
	default:
		panic("impossible")
	}
}

// readerType tells which backing field of StreamReader is in use.
type readerType int

const (
	readerTypeStream readerType = iota
	readerTypeArray
	readerTypeMultiStream
	readerTypeWithConvert
	readerTypeChild
)

// iStreamReader is a type-erased view of StreamReader[T] (implemented by
// recvAny, copyAny, Close and SetAutomaticClose above), presumably for code
// that handles streams whose element type is not known statically.
type iStreamReader interface {
	recvAny() (any, error)
	copyAny(int) []iStreamReader
	Close()
	SetAutomaticClose()
}

// stream is a channel-based stream with 1 sender and 1 receiver.
// The sender calls closeSend() to notify the receiver that the stream sender has finished.
// The receiver calls closeRecv() to notify the sender that the receiver has stopped receiving.
type stream[T any] struct {
	// items carries data from the sender; it is closed by closeSend.
	items chan streamItem[T]

	// closed is closed by closeRecv to tell the sender to stop.
	closed chan struct{}

	automaticClose bool
	closedFlag     *uint32 // 0 = not closed, 1 = closed, only used when automaticClose is set
}

// streamItem is one element sent over a stream: a chunk and an optional error.
type streamItem[T any] struct {
	chunk T
	err   error
}

// newStream creates a stream whose items channel has buffer size cap.
func newStream[T any](cap int) *stream[T] {
	return &stream[T]{
		items:  make(chan streamItem[T], cap),
		closed: make(chan struct{}),
	}
}

// asReader wraps s in a Pipe-style StreamReader.
func (s *stream[T]) asReader() *StreamReader[T] {
	return &StreamReader[T]{typ: readerTypeStream, st: s}
}

// recv returns the next item; once items is closed and drained it returns the
// zero chunk with io.EOF.
func (s *stream[T]) recv() (chunk T, err error) {
	item, ok := <-s.items
	if !ok {
		item.err = io.EOF
	}

	return item.chunk, item.err
}

// send enqueues one item and reports whether the receiver has closed the
// stream (in which case the item is dropped).
func (s *stream[T]) send(chunk T, err error) (closed bool) {
	// if the stream is closed, return immediately
	// (When several select cases are ready Go picks one at random, so without
	// this pre-check the select below could still enqueue into a stream whose
	// receiver has already closed it.)
	select {
	case <-s.closed:
		return true
	default:
	}

	item := streamItem[T]{chunk, err}

	select {
	case <-s.closed:
		return true
	case s.items <- item:
		return false
	}
}

// closeSend marks the end of data. It must be called at most once: closing
// items twice panics, as does a later send on the closed channel.
func (s *stream[T]) closeSend() {
	close(s.items)
}

// closeRecv tells the sender that the receiver has stopped. In
// automatic-close mode, the compare-and-swap on closedFlag lets both an
// explicit Close and the finalizer call this safely. Otherwise a second call
// panics because closed would be closed twice.
func (s *stream[T]) closeRecv() {
	if s.automaticClose {
		if atomic.CompareAndSwapUint32(s.closedFlag, 0, 1) {
			close(s.closed)
		}
		return
	}

	close(s.closed)
}

// StreamReaderFromArray creates a StreamReader from a given slice of elements.
// It takes an array of type T and returns a pointer to a StreamReader[T].
// This allows for streaming the elements of the array in a controlled manner.
// eg.
// // sr := schema.StreamReaderFromArray([]int{1, 2, 3}) // defer sr.Close() // // for chunk, err := sr.Recv() { // fmt.Println(chunk) // } func StreamReaderFromArray[T any](arr []T) *StreamReader[T] { return &StreamReader[T]{ar: &arrayReader[T]{arr: arr}, typ: readerTypeArray} } type arrayReader[T any] struct { arr []T index int } func (ar *arrayReader[T]) recv() (T, error) { if ar.index < len(ar.arr) { ret := ar.arr[ar.index] ar.index++ return ret, nil } var t T return t, io.EOF } func (ar *arrayReader[T]) copy(n int) []*arrayReader[T] { ret := make([]*arrayReader[T], n) for i := 0; i < n; i++ { ret[i] = &arrayReader[T]{ arr: ar.arr, index: ar.index, } } return ret } func (ar *arrayReader[T]) toStream() *stream[T] { return arrToStream(ar.arr[ar.index:]) } type multiArrayReader[T any] struct { ars []*arrayReader[T] index int } type multiStreamReader[T any] struct { sts []*stream[T] itemsCases []reflect.SelectCase nonClosed []int sourceReaderNames []string } func newMultiStreamReader[T any](sts []*stream[T]) *multiStreamReader[T] { var itemsCases []reflect.SelectCase if len(sts) > maxSelectNum { itemsCases = make([]reflect.SelectCase, len(sts)) for i, st := range sts { itemsCases[i] = reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(st.items), } } } nonClosed := make([]int, len(sts)) for i := range sts { nonClosed[i] = i } return &multiStreamReader[T]{ sts: sts, itemsCases: itemsCases, nonClosed: nonClosed, } } func (msr *multiStreamReader[T]) recv() (T, error) { for len(msr.nonClosed) > 0 { var chosen int var ok bool if len(msr.nonClosed) > maxSelectNum { var recv reflect.Value chosen, recv, ok = reflect.Select(msr.itemsCases) if ok { item := recv.Interface().(streamItem[T]) return item.chunk, item.err } msr.itemsCases[chosen].Chan = reflect.Value{} } else { var item *streamItem[T] chosen, item, ok = receiveN(msr.nonClosed, msr.sts) if ok { return item.chunk, item.err } } // delete the closed stream for i := range msr.nonClosed { if 
msr.nonClosed[i] == chosen { msr.nonClosed = append(msr.nonClosed[:i], msr.nonClosed[i+1:]...) break } } if len(msr.sourceReaderNames) > 0 { var t T return t, &SourceEOF{msr.sourceReaderNames[chosen]} } } var t T return t, io.EOF } func (msr *multiStreamReader[T]) nonClosedStreams() []*stream[T] { ret := make([]*stream[T], len(msr.nonClosed)) for i, idx := range msr.nonClosed { ret[i] = msr.sts[idx] } return ret } func (msr *multiStreamReader[T]) close() { for _, s := range msr.sts { s.closeRecv() } } func (msr *multiStreamReader[T]) toStream() *stream[T] { return toStream[T, *multiStreamReader[T]](msr) } type streamReaderWithConvert[T any] struct { sr iStreamReader convert func(any) (T, error) errWrapper func(error) error } func newStreamReaderWithConvert[T any](origin iStreamReader, convert func(any) (T, error), opts ...ConvertOption) *StreamReader[T] { opt := &convertOptions{} for _, o := range opts { o(opt) } srw := &streamReaderWithConvert[T]{ sr: origin, convert: convert, errWrapper: opt.ErrWrapper, } return &StreamReader[T]{ typ: readerTypeWithConvert, srw: srw, } } type convertOptions struct { ErrWrapper func(error) error } type ConvertOption func(*convertOptions) // WithErrWrapper wraps the first error encountered in a stream reader during conversion by StreamReaderWithConvert. // The error returned by the convert function will not be wrapped. // If the returned err is nil or is ErrNoValue, the stream chunk will be ignored func WithErrWrapper(wrapper func(error) error) ConvertOption { return func(o *convertOptions) { o.ErrWrapper = wrapper } } // StreamReaderWithConvert returns a new StreamReader[D] that wraps sr and // applies convert to every element. The original reader sr must not be used // after calling this function. // // Filtering: if convert returns [ErrNoValue], the element is silently dropped // and the next element is read. This lets you strip empty or irrelevant chunks // without surfacing an error to the caller. 
// // Error wrapping: use [WithErrWrapper] to wrap non-convert errors (e.g. those // arriving from an upstream source) before they reach the caller. // // intReader := schema.StreamReaderFromArray([]int{0, 1, 2, 3}) // strReader := schema.StreamReaderWithConvert(intReader, // func(i int) (string, error) { // if i == 0 { // return "", schema.ErrNoValue // skip zero // } // return fmt.Sprintf("val_%d", i), nil // }) // defer strReader.Close() // // Recv yields "val_1", "val_2", "val_3" func StreamReaderWithConvert[T, D any](sr *StreamReader[T], convert func(T) (D, error), opts ...ConvertOption) *StreamReader[D] { c := func(a any) (D, error) { return convert(a.(T)) } return newStreamReaderWithConvert(sr, c, opts...) } func (srw *streamReaderWithConvert[T]) recv() (T, error) { for { out, err := srw.sr.recvAny() if err != nil { var t T if err == io.EOF { return t, err } if srw.errWrapper != nil { err = srw.errWrapper(err) if err != nil && !errors.Is(err, ErrNoValue) { return t, err } } return t, err } t, err := srw.convert(out) if err == nil { return t, nil } if !errors.Is(err, ErrNoValue) { return t, err } } } func (srw *streamReaderWithConvert[T]) close() { srw.sr.Close() } type reader[T any] interface { recv() (T, error) close() } func toStream[T any, Reader reader[T]](r Reader) *stream[T] { ret := newStream[T](5) go func() { defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) var chunk T _ = ret.send(chunk, e) } ret.closeSend() r.close() }() for { out, err := r.recv() if err == io.EOF { break } closed := ret.send(out, err) if closed { break } } }() return ret } func (srw *streamReaderWithConvert[T]) toStream() *stream[T] { return toStream[T, *streamReaderWithConvert[T]](srw) } type cpStreamElement[T any] struct { once sync.Once next *cpStreamElement[T] item streamItem[T] } // copyStreamReaders creates multiple independent StreamReaders from a single StreamReader. 
// Each child StreamReader can read from the original stream independently.
// All children start at the same empty tail node, so they observe the same
// sequence of elements; the original reader is closed only after every child
// has been closed (see parentStreamReader.close).
func copyStreamReaders[T any](sr *StreamReader[T], n int) []*StreamReader[T] {
	cpsr := &parentStreamReader[T]{
		sr:            sr,
		subStreamList: make([]*cpStreamElement[T], n),
		closedNum:     0,
	}

	// Initialize subStreamList with an empty element, which acts like a tail node.
	// A nil element (used for dereference) represents that the child has been closed.
	// It is challenging to link the previous and current elements when the length of the original channel is unknown.
	// Additionally, using a previous pointer complicates dereferencing elements, possibly requiring reference counting.
	elem := &cpStreamElement[T]{}

	for i := range cpsr.subStreamList {
		cpsr.subStreamList[i] = elem
	}

	ret := make([]*StreamReader[T], n)
	for i := range ret {
		ret[i] = &StreamReader[T]{
			csr: &childStreamReader[T]{
				parent: cpsr,
				index:  i,
			},
			typ: readerTypeChild,
		}
	}

	return ret
}

// parentStreamReader is the state shared by all children produced by
// copyStreamReaders.
type parentStreamReader[T any] struct {
	// sr is the original StreamReader.
	sr *StreamReader[T]

	// subStreamList maps each child's index to its latest read chunk.
	// Each value comes from a hidden linked list of cpStreamElement.
	subStreamList []*cpStreamElement[T]

	// closedNum is the count of closed children.
	closedNum uint32
}

// peek is not safe for concurrent use with the same idx but is safe for different idx.
// Ensure that each child StreamReader uses a for-loop in a single goroutine.
// Once io.EOF is reached the child's position is not advanced, so further
// calls keep returning io.EOF.
func (p *parentStreamReader[T]) peek(idx int) (t T, err error) {
	elem := p.subStreamList[idx]
	if elem == nil {
		// Unexpected call to receive after the child has been closed.
		return t, ErrRecvAfterClosed
	}

	// The sync.Once here is used to:
	// 1. Write the content of this cpStreamElement.
	// 2. Initialize the 'next' field of this cpStreamElement with an empty cpStreamElement,
	//    similar to the initialization in copyStreamReaders.
	elem.once.Do(func() {
		t, err = p.sr.Recv()
		elem.item = streamItem[T]{chunk: t, err: err}
		if err != io.EOF {
			elem.next = &cpStreamElement[T]{}
			// NOTE(review): the same assignment is repeated after Do below; it
			// appears redundant here but is harmless.
			p.subStreamList[idx] = elem.next
		}
	})

	// The element has been set and will not be modified again.
	// Therefore, children can read this element's content and 'next' pointer concurrently.
	t = elem.item.chunk
	err = elem.item.err
	if err != io.EOF {
		p.subStreamList[idx] = elem.next
	}

	return t, err
}

// close marks child idx as closed; repeated calls for the same idx are no-ops.
// When the last child is closed, the original reader is closed.
// NOTE(review): the nil check and the assignment are not atomic (only closedNum
// is), so concurrent close calls for the same idx are not guarded.
func (p *parentStreamReader[T]) close(idx int) {
	if p.subStreamList[idx] == nil {
		return // avoid close multiple times
	}

	p.subStreamList[idx] = nil

	curClosedNum := atomic.AddUint32(&p.closedNum, 1)

	allClosed := int(curClosedNum) == len(p.subStreamList)
	if allClosed {
		p.sr.Close()
	}
}

// childStreamReader is one copy produced by copyStreamReaders; index selects
// its slot in the parent's subStreamList.
type childStreamReader[T any] struct {
	parent *parentStreamReader[T]
	index  int
}

func (csr *childStreamReader[T]) recv() (T, error) {
	return csr.parent.peek(csr.index)
}

// toStream pumps this copy into a new stream via the generic toStream.
func (csr *childStreamReader[T]) toStream() *stream[T] {
	return toStream[T, *childStreamReader[T]](csr)
}

func (csr *childStreamReader[T]) close() {
	csr.parent.close(csr.index)
}

// MergeStreamReaders fans in multiple StreamReaders into a single StreamReader.
// Elements from all source streams are interleaved in arrival order (non-deterministic).
// The merged reader reaches EOF only after every source stream has been exhausted.
//
// Callers must still close the merged reader; it propagates the close signal
// to all underlying sources.
//
// Use [MergeNamedStreamReaders] instead when you need to know which source
// stream ended first (it emits a [SourceEOF] per-source EOF rather than
// silently discarding them).
//
// Returns nil if srs is empty. A single reader is returned unchanged.
func MergeStreamReaders[T any](srs []*StreamReader[T]) *StreamReader[T] {
	if len(srs) < 1 {
		return nil
	}

	if len(srs) < 2 {
		return srs[0]
	}

	var arr []T
	var ss []*stream[T]

	for _, sr := range srs {
		switch sr.typ {
		case readerTypeStream:
			ss = append(ss, sr.st)
		case readerTypeArray:
			// Unread array elements from all array sources are pooled and
			// emitted later through a single pre-filled stream.
			arr = append(arr, sr.ar.arr[sr.ar.index:]...)
		case readerTypeMultiStream:
			// Flatten already-merged readers into their still-open sources.
			ss = append(ss, sr.msr.nonClosedStreams()...)
		case readerTypeWithConvert:
			ss = append(ss, sr.srw.toStream())
		case readerTypeChild:
			ss = append(ss, sr.csr.toStream())
		default:
			panic("impossible")
		}
	}

	// Only array sources: the result can stay a plain array reader.
	if len(ss) == 0 {
		return &StreamReader[T]{
			typ: readerTypeArray,
			ar: &arrayReader[T]{
				arr:   arr,
				index: 0,
			},
		}
	}

	if len(arr) != 0 {
		s := arrToStream(arr)
		ss = append(ss, s)
	}

	return &StreamReader[T]{
		typ: readerTypeMultiStream,
		msr: newMultiStreamReader(ss),
	}
}

// MergeNamedStreamReaders merges multiple named StreamReaders into one.
// Unlike [MergeStreamReaders], when a source stream reaches EOF the merged
// reader emits a [SourceEOF] error (instead of silently continuing) so you can
// detect exactly which source finished. Use [GetSourceName] to retrieve the
// name from a SourceEOF error. The merged reader itself signals io.EOF only
// after all named sources are exhausted.
//
// This is useful when downstream logic must react differently to each source
// completing — for example, draining one agent's output before proceeding:
//
//	namedStreams := map[string]*schema.StreamReader[string]{
//		"agent_a": srA,
//		"agent_b": srB,
//	}
//	merged := schema.MergeNamedStreamReaders(namedStreams)
//	defer merged.Close()
//	for {
//		chunk, err := merged.Recv()
//		// check for a per-source EOF before the final io.EOF
//		if name, ok := schema.GetSourceName(err); ok {
//			fmt.Printf("%s finished\n", name)
//			continue
//		}
//		if errors.Is(err, io.EOF) {
//			break
//		}
//		if err != nil {
//			return err
//		}
//		process(chunk)
//	}
//
// Returns nil if srs is empty.
func MergeNamedStreamReaders[T any](srs map[string]*StreamReader[T]) *StreamReader[T] {
	if len(srs) < 1 {
		return nil
	}

	ss := make([]*StreamReader[T], len(srs))
	names := make([]string, len(srs))

	// Map iteration order is random, but ss[i] and names[i] stay paired.
	i := 0
	for name, sr := range srs {
		ss[i] = sr
		names[i] = name
		i++
	}

	return InternalMergeNamedStreamReaders(ss, names)
}

// InternalMergeNamedStreamReaders merges multiple readers with their names
// into a single multi-stream reader.
// names[i] labels srs[i] in the SourceEOF emitted when that source ends, so
// names is expected to have at least len(srs) entries.
func InternalMergeNamedStreamReaders[T any](srs []*StreamReader[T], names []string) *StreamReader[T] { ss := make([]*stream[T], len(srs)) for i, sr := range srs { ss[i] = sr.toStream() } msr := newMultiStreamReader(ss) msr.sourceReaderNames = names return &StreamReader[T]{ typ: readerTypeMultiStream, msr: msr, } } ================================================ FILE: schema/stream_copy_external_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package schema import ( "fmt" "io" "runtime" "sort" "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) func TestStream1(t *testing.T) { runtime.GOMAXPROCS(1) sr, sw := Pipe[int](0) go func() { for i := 0; i < 100; i++ { sw.Send(i, nil) time.Sleep(3 * time.Millisecond) } sw.Close() }() copied := sr.Copy(2) var ( now = time.Now().UnixMilli() ts = []int64{now, now} tsOld = []int64{now, now} ) var count int32 wg := sync.WaitGroup{} wg.Add(2) go func() { i := 0 s := copied[0] for { n, e := s.Recv() if e != nil { if e == io.EOF { break } } tsOld[0] = ts[0] ts[0] = time.Now().UnixMilli() interval := ts[0] - tsOld[0] if interval >= 6 { atomic.AddInt32(&count, 1) } assert.Equal(t, i, n) i++ } wg.Done() }() go func() { i := 0 s := copied[1] for { n, e := s.Recv() if e != nil { if e == io.EOF { break } } tsOld[1] = ts[1] ts[1] = time.Now().UnixMilli() interval := ts[1] - tsOld[1] if interval >= 6 { atomic.AddInt32(&count, 1) } assert.Equal(t, i, n) i++ } wg.Done() }() wg.Wait() t.Logf("count= %d", count) } type info struct { idx int ts int64 after int64 content string } func TestCopyDelay(t *testing.T) { runtime.GOMAXPROCS(10) n := 3 //m := 100 s := newStream[string](0) scp := s.asReader().Copy(n) go func() { s.send("1", nil) s.send("2", nil) time.Sleep(time.Second) s.send("3", nil) s.closeSend() }() wg := sync.WaitGroup{} wg.Add(n) infoList := make([][]info, n) for i := 0; i < n; i++ { j := i go func() { defer func() { scp[j].Close() wg.Done() }() for { lastTime := time.Now() str, err := scp[j].Recv() if err == io.EOF { break } now := time.Now() infoList[j] = append(infoList[j], info{ idx: j, ts: now.UnixMicro(), after: now.Sub(lastTime).Milliseconds(), content: str, }) } }() } wg.Wait() infos := make([]info, 0) for _, infoL := range infoList { infos = append(infos, infoL...) 
} sort.Slice(infos, func(i, j int) bool { return infos[i].ts < infos[j].ts }) for _, info := range infos { fmt.Printf("child[%d] ts[%d] after[%5dms] content[%s]\n", info.idx, info.ts, info.after, info.content) } } ================================================ FILE: schema/stream_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package schema import ( "errors" "fmt" "io" "math/rand" "sync" "testing" "time" "github.com/stretchr/testify/assert" ) func TestStream(t *testing.T) { s := newStream[int](0) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() for i := 0; i < 10; i++ { closed := s.send(i, nil) if closed { break } } s.closeSend() }() i := 0 for { i++ if i > 5 { s.closeRecv() break } v, err := s.recv() if err != nil { assert.ErrorIs(t, err, io.EOF) break } t.Log(v) } wg.Wait() } func TestStreamCopy(t *testing.T) { s := newStream[string](10) srs := s.asReader().Copy(2) s.send("a", nil) s.send("b", nil) s.send("c", nil) s.closeSend() defer func() { for _, sr := range srs { sr.Close() } }() for { v, err := srs[0].Recv() if errors.Is(err, io.EOF) { break } if err != nil { t.Fatal(err) } t.Log("copy 01 recv", v) } for { v, err := srs[1].Recv() if errors.Is(err, io.EOF) { break } if err != nil { t.Fatal(err) } t.Log("copy 02 recv", v) } for { v, err := s.recv() if errors.Is(err, io.EOF) { break } if err != nil { t.Fatal(err) } t.Log("recv origin", v) } t.Log("done") } func 
TestNewStreamCopy(t *testing.T) { t.Run("test one index recv channel blocked while other indexes could recv", func(t *testing.T) { s := newStream[string](1) scp := s.asReader().Copy(2) var t1, t2 time.Time go func() { s.send("a", nil) t1 = time.Now() time.Sleep(time.Millisecond * 200) s.send("a", nil) s.closeSend() }() wg := sync.WaitGroup{} wg.Add(2) go func() { defer func() { scp[0].Close() wg.Done() }() for { str, err := scp[0].Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, str, "a") } }() go func() { defer func() { scp[1].Close() wg.Done() }() time.Sleep(time.Millisecond * 100) for { str, err := scp[1].Recv() if err == io.EOF { break } if t2.IsZero() { t2 = time.Now() } assert.NoError(t, err) assert.Equal(t, str, "a") } }() wg.Wait() assert.True(t, t2.Sub(t1) < time.Millisecond*200) }) t.Run("test one index recv channel blocked and other index closed", func(t *testing.T) { s := newStream[string](1) scp := s.asReader().Copy(2) go func() { s.send("a", nil) time.Sleep(time.Millisecond * 200) s.send("a", nil) s.closeSend() }() wg := sync.WaitGroup{} wg.Add(2) //buf := scp[0].csr.parent.mem.buf go func() { defer func() { scp[0].Close() wg.Done() }() for { str, err := scp[0].Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, str, "a") } }() go func() { time.Sleep(time.Millisecond * 100) scp[1].Close() scp[1].Close() // try close multiple times wg.Done() }() wg.Wait() //assert.Equal(t, 0, buf.Len()) }) t.Run("test long time recv", func(t *testing.T) { s := newStream[int](2) n := 1000 go func() { for i := 0; i < n; i++ { s.send(i, nil) } s.closeSend() }() m := 100 wg := sync.WaitGroup{} wg.Add(m) copies := s.asReader().Copy(m) for i := 0; i < m; i++ { idx := i go func() { cp := copies[idx] l := 0 defer func() { assert.Equal(t, 1000, l) cp.Close() wg.Done() }() for { exp, err := cp.Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, exp, l) l++ } }() } wg.Wait() //memo := copies[0].csr.parent.mem 
//assert.Equal(t, true, memo.hasFinished) //assert.Equal(t, 0, memo.buf.Len()) }) t.Run("test closes", func(t *testing.T) { s := newStream[int](20) n := 1000 go func() { for i := 0; i < n; i++ { s.send(i, nil) } s.closeSend() }() m := 100 wg := sync.WaitGroup{} wg.Add(m) wgEven := sync.WaitGroup{} wgEven.Add(m / 2) sr := s.asReader() sr.SetAutomaticClose() copies := sr.Copy(m) for i := 0; i < m; i++ { idx := i go func() { cp := copies[idx] l := 0 defer func() { cp.Close() wg.Done() if idx%2 == 0 { wgEven.Done() } }() for { if idx%2 == 0 && l == idx { break } exp, err := cp.Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, exp, l) l++ } }() } wgEven.Wait() wg.Wait() assert.Equal(t, m, int(copies[0].csr.parent.closedNum)) }) t.Run("test reader do no close", func(t *testing.T) { s := newStream[int](20) n := 1000 go func() { for i := 0; i < n; i++ { s.send(i, nil) } s.closeSend() }() m := 4 wg := sync.WaitGroup{} wg.Add(m) copies := s.asReader().Copy(m) for i := 0; i < m; i++ { idx := i cp := copies[idx] cp.SetAutomaticClose() go func() { l := 0 defer func() { wg.Done() }() for { exp, err := cp.Recv() if err == io.EOF { break } assert.NoError(t, err) assert.Equal(t, exp, l) l++ } }() } wg.Wait() assert.Equal(t, 0, int(copies[0].csr.parent.closedNum)) // not closed }) } func checkStream(s *StreamReader[int]) error { defer s.Close() for i := 0; i < 10; i++ { chunk, err := s.Recv() if err != nil { return err } if chunk != i { return fmt.Errorf("receive err, expected:%d, actual: %d", i, chunk) } } _, err := s.Recv() if err != io.EOF { return fmt.Errorf("close chan fail") } return nil } func testStreamN(cap, n int) error { s := newStream[int](cap) go func() { for i := 0; i < 10; i++ { s.send(i, nil) } s.closeSend() }() vs := s.asReader().Copy(n) err := checkStream(vs[0]) if err != nil { return err } vs = vs[1].Copy(n) err = checkStream(vs[0]) if err != nil { return err } vs = vs[1].Copy(n) err = checkStream(vs[0]) if err != nil { return err } return 
nil } func TestCopy(t *testing.T) { for i := 0; i < 10; i++ { for j := 2; j < 10; j++ { err := testStreamN(i, j) if err != nil { t.Fatal(err) } } } } func TestCopy5(t *testing.T) { s := newStream[int](0) go func() { for i := 0; i < 10; i++ { closed := s.send(i, nil) if closed { fmt.Printf("has closed") } } s.closeSend() }() vs := s.asReader().Copy(5) time.Sleep(time.Second) defer func() { for _, v := range vs { v.Close() } }() for i := 0; i < 10; i++ { chunk, err := vs[0].Recv() if err != nil { t.Fatal(err) } if chunk != i { t.Fatalf("receive err, expected:%d, actual: %d", i, chunk) } } _, err := vs[0].Recv() if err != io.EOF { t.Fatalf("copied stream reader cannot return EOF") } _, err = vs[0].Recv() if err != io.EOF { t.Fatalf("copied stream reader cannot return EOF repeatedly") } } func TestStreamReaderWithConvert(t *testing.T) { s := newStream[int](2) var cntA int var e error convA := func(src int) (int, error) { if src == 1 { return 0, fmt.Errorf("mock err") } return src, nil } sta := StreamReaderWithConvert[int, int](s.asReader(), convA) sta.SetAutomaticClose() s.send(1, nil) s.send(2, nil) s.closeSend() for { item, err := sta.Recv() if err != nil { if err == io.EOF { break } e = err continue } cntA += item } assert.NotNil(t, e) assert.Equal(t, cntA, 2) } func TestArrayStreamCombined(t *testing.T) { asr := &StreamReader[int]{ typ: readerTypeArray, ar: &arrayReader[int]{ arr: []int{0, 1, 2}, index: 0, }, } s := newStream[int](3) for i := 3; i < 6; i++ { s.send(i, nil) } s.closeSend() nSR := MergeStreamReaders([]*StreamReader[int]{asr, s.asReader()}) nSR.SetAutomaticClose() record := make([]bool, 6) for i := 0; i < 6; i++ { chunk, err := nSR.Recv() if err != nil { t.Fatal(err) } if record[chunk] { t.Fatal("record duplicated") } record[chunk] = true } _, err := nSR.Recv() if err != io.EOF { t.Fatal("reader haven't finish correctly") } for i := range record { if !record[i] { t.Fatal("record missing") } } } func TestMultiStream(t *testing.T) { var sts 
[]*stream[int] sum := 0 for i := 0; i < 10; i++ { size := rand.Intn(10) + 1 sum += size st := newStream[int](size) for j := 1; j <= size; j++ { st.send(j&0xffff+i<<16, nil) } st.closeSend() sts = append(sts, st) } mst := newMultiStreamReader(sts) receiveList := make([]int, 10) for i := 0; i < sum; i++ { chunk, err := mst.recv() if err != nil { t.Fatal(err) } if receiveList[chunk>>16] >= chunk&0xffff { t.Fatal("out of order") } receiveList[chunk>>16] = chunk & 0xffff } _, err := mst.recv() if err != io.EOF { t.Fatal("end stream haven't return EOF") } } // TestMergeNamedStreamReaders tests the functionality of MergeNamedStreamReaders // with a focus on SourceEOF error handling. func TestMergeNamedStreamReaders(t *testing.T) { t.Run("BasicSourceEOF", func(t *testing.T) { // Create two named streams sr1, sw1 := Pipe[string](2) sr2, sw2 := Pipe[string](2) // Merge the streams with names namedStreams := map[string]*StreamReader[string]{ "stream1": sr1, "stream2": sr2, } mergedSR := MergeNamedStreamReaders(namedStreams) mergedSR.SetAutomaticClose() // Send data to the first stream and close it immediately go func() { defer sw1.Close() sw1.Send("data1-1", nil) sw1.Send("data1-2", nil) // First stream ends }() // Send data to the second stream with a delay before closing go func() { defer sw2.Close() sw2.Send("data2-1", nil) sw2.Send("data2-2", nil) sw2.Send("data2-3", nil) // Second stream ends }() // Track received data and EOF sources receivedData := make(map[string][]string) eofSources := make([]string, 0, 2) for { chunk, err := mergedSR.Recv() if err != nil { // Check if it's a SourceEOF error if sourceName, ok := GetSourceName(err); ok { eofSources = append(eofSources, sourceName) t.Logf("Received EOF from source: %s", sourceName) continue // Continue receiving from other streams } // If it's a regular EOF, all streams have ended if errors.Is(err, io.EOF) { break } // Handle other errors t.Errorf("Error receiving data: %v", err) break } // Categorize data by prefix if 
len(chunk) >= 5 { prefix := chunk[:5] if prefix == "data1" { receivedData["stream1"] = append(receivedData["stream1"], chunk) } else if prefix == "data2" { receivedData["stream2"] = append(receivedData["stream2"], chunk) } } } // Verify we received both SourceEOF errors if len(eofSources) != 2 { t.Errorf("Expected 2 SourceEOF errors, got %d", len(eofSources)) } // Verify the source names are correct expectedSources := map[string]bool{"stream1": false, "stream2": false} for _, source := range eofSources { if _, exists := expectedSources[source]; !exists { t.Errorf("Unexpected source name: %s", source) } else { expectedSources[source] = true } } // Verify all expected sources were seen for source, seen := range expectedSources { if !seen { t.Errorf("Did not receive SourceEOF for %s", source) } } // Verify we received all expected data if len(receivedData["stream1"]) != 2 { t.Errorf("Expected 2 items from stream1, got %d", len(receivedData["stream1"])) } if len(receivedData["stream2"]) != 3 { t.Errorf("Expected 3 items from stream2, got %d", len(receivedData["stream2"])) } }) t.Run("EmptyStream", func(t *testing.T) { // Create two streams, one will be empty sr1, sw1 := Pipe[string](2) sr2, sw2 := Pipe[string](2) // Close the first stream immediately to make it empty sw1.Close() // Merge the streams with names namedStreams := map[string]*StreamReader[string]{ "empty": sr1, "data": sr2, } mergedSR := MergeNamedStreamReaders(namedStreams) mergedSR.SetAutomaticClose() // Send data to the second stream go func() { defer sw2.Close() sw2.Send("test-data", nil) }() // Track received EOFs and data eofSources := make(map[string]bool, 2) receivedData := make([]string, 0, 1) for { chunk, err := mergedSR.Recv() if err != nil { if sourceName, ok := GetSourceName(err); ok { eofSources[sourceName] = true continue } if errors.Is(err, io.EOF) { break } t.Errorf("Error receiving data: %v", err) break } receivedData = append(receivedData, chunk) } // Verify we received EOF from the empty 
stream if len(eofSources) != 2 { t.Errorf("Expected 2 SourceEOF errors, got %d", len(eofSources)) } if _, exist := eofSources["empty"]; !exist { t.Errorf("Expected EOF from 'empty' stream, got '%v'", eofSources) } if _, exist := eofSources["data"]; !exist { t.Errorf("Expected EOF from 'data' stream, got '%v'", eofSources) } // Verify we received the data from the non-empty stream if len(receivedData) != 1 || receivedData[0] != "test-data" { t.Errorf("Expected to receive 'test-data', got %v", receivedData) } }) t.Run("ArraySource", func(t *testing.T) { // Create three named streams sr1, sw1 := Pipe[string](2) sr2, sw2 := Pipe[string](2) sr3 := StreamReaderFromArray([]string{"data3-1", "data3-2", "data3-3"}) // Merge the streams with names namedStreams := map[string]*StreamReader[string]{ "stream1": sr1, "stream2": sr2, "stream3": sr3, } mergedSR := MergeNamedStreamReaders(namedStreams) mergedSR.SetAutomaticClose() // Send data and close streams in sequence go func() { // First stream sends one item then closes sw1.Send("data1", nil) sw1.Close() // Second stream sends two items then closes sw2.Send("data2-1", nil) sw2.Send("data2-2", nil) sw2.Close() }() // Track EOF order and data count eofOrder := make([]string, 0, 3) dataCount := 0 for { _, err := mergedSR.Recv() if err != nil { if sourceName, ok := GetSourceName(err); ok { eofOrder = append(eofOrder, sourceName) continue } if errors.Is(err, io.EOF) { break } t.Errorf("Error receiving data: %v", err) break } dataCount++ } // Verify EOF count if len(eofOrder) != 3 { t.Errorf("Expected 3 SourceEOF errors, got %d", len(eofOrder)) } // Verify data count if dataCount != 6 { t.Errorf("Expected 6 data items, got %d", dataCount) } }) t.Run("ErrorPropagation", func(t *testing.T) { // Create two streams sr1, sw1 := Pipe[string](2) sr2, sw2 := Pipe[string](2) // Merge the streams with names namedStreams := map[string]*StreamReader[string]{ "normal": sr1, "error": sr2, } mergedSR := MergeNamedStreamReaders(namedStreams) defer 
mergedSR.Close() testError := errors.New("test error") // Send normal data to first stream go func() { defer sw1.Close() sw1.Send("normal-data", nil) }() // Send error to second stream go func() { defer sw2.Close() sw2.Send("", testError) }() // Track received errors var receivedError error for { _, err := mergedSR.Recv() if err != nil { // Skip SourceEOF errors if _, ok := GetSourceName(err); ok { continue } if errors.Is(err, io.EOF) { break } // Store the first non-EOF error receivedError = err break } } // Verify we received the test error if receivedError == nil || receivedError.Error() != testError.Error() { t.Errorf("Expected error '%v', got '%v'", testError, receivedError) } }) } ================================================ FILE: schema/tool.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package schema import ( "sort" "github.com/eino-contrib/jsonschema" orderedmap "github.com/wk8/go-ordered-map/v2" ) // DataType is the type of the parameter. // It must be one of the following values: "object", "number", "integer", "string", "array", "null", "boolean", which is the same as the type of the parameter in JSONSchema. type DataType string // Supported JSONSchema data types for tool parameters. 
const ( Object DataType = "object" Number DataType = "number" Integer DataType = "integer" String DataType = "string" Array DataType = "array" Null DataType = "null" Boolean DataType = "boolean" ) // ToolChoice controls how the model uses the tools provided to it. // Pass as part of the model option via [model.WithToolChoice]. type ToolChoice string const ( // ToolChoiceForbidden instructs the model not to call any tools, even if // tools are bound. The model responds with a plain text message instead. // Corresponds to "none" in OpenAI Chat Completion. ToolChoiceForbidden ToolChoice = "forbidden" // ToolChoiceAllowed lets the model decide: it may generate a plain message // or call one or more tools. This is the default when tools are provided. // Corresponds to "auto" in OpenAI Chat Completion. ToolChoiceAllowed ToolChoice = "allowed" // ToolChoiceForced requires the model to call at least one tool. Use this // when you want to guarantee structured output via tool calling. // Corresponds to "required" in OpenAI Chat Completion. ToolChoiceForced ToolChoice = "forced" ) // ToolInfo describes a tool that can be passed to a ChatModel via // [ToolCallingChatModel.WithTools] or [ChatModel.BindTools]. // // Name should be concise and unique within the tool set. Desc should explain // when and why to use the tool; few-shot examples in Desc significantly improve // model accuracy. ParamsOneOf may be nil for tools that take no arguments. type ToolInfo struct { // The unique name of the tool that clearly communicates its purpose. Name string // Used to tell the model how/when/why to use the tool. // You can provide few-shot examples as a part of the description. Desc string // Extra is the extra information for the tool. Extra map[string]any // The parameters the functions accepts (different models may require different parameter types). 
// can be described in two ways: // - use params: schema.NewParamsOneOfByParams(params) // - use jsonschema: schema.NewParamsOneOfByJSONSchema(jsonschema) // If is nil, signals that the tool does not need any input parameter *ParamsOneOf } // ParameterInfo is the information of a parameter. // It is used to describe the parameters of a tool. type ParameterInfo struct { // The type of the parameter. Type DataType // The element type of the parameter, only for array. ElemInfo *ParameterInfo // The sub parameters of the parameter, only for object. SubParams map[string]*ParameterInfo // The description of the parameter. Desc string // The enum values of the parameter, only for string. Enum []string // Whether the parameter is required. Required bool } // ParamsOneOf holds a tool's parameter schema using exactly one of two // representations. Choose the one that best fits your needs: // // 1. [NewParamsOneOfByParams] — lightweight: describe parameters as a // map[string]*[ParameterInfo]. Covers the most common cases (scalars, // arrays, nested objects, enums, required flags). // // 2. [NewParamsOneOfByJSONSchema] — powerful: supply a full // *jsonschema.Schema (JSON Schema 2020-12). Required when you need // features not expressible via ParameterInfo, such as anyOf, oneOf, or // $defs references. [utils.InferTool] generates this form automatically // from Go struct tags. // // You must use exactly one constructor — setting both fields is invalid. // If ParamsOneOf is nil, the tool takes no input parameters. type ParamsOneOf struct { // use NewParamsOneOfByParams to set this field params map[string]*ParameterInfo jsonschema *jsonschema.Schema } // NewParamsOneOfByParams creates a ParamsOneOf with map[string]*ParameterInfo. func NewParamsOneOfByParams(params map[string]*ParameterInfo) *ParamsOneOf { return &ParamsOneOf{ params: params, } } // NewParamsOneOfByJSONSchema creates a ParamsOneOf with *jsonschema.Schema. 
func NewParamsOneOfByJSONSchema(s *jsonschema.Schema) *ParamsOneOf { return &ParamsOneOf{ jsonschema: s, } } // ToJSONSchema parses ParamsOneOf, converts the parameter description that user actually provides, into the format ready to be passed to Model. func (p *ParamsOneOf) ToJSONSchema() (*jsonschema.Schema, error) { if p == nil { return nil, nil } if p.params != nil { sc := &jsonschema.Schema{ Properties: orderedmap.New[string, *jsonschema.Schema](), Type: string(Object), Required: make([]string, 0, len(p.params)), } keys := make([]string, 0, len(p.params)) for k := range p.params { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { v := p.params[k] sc.Properties.Set(k, paramInfoToJSONSchema(v)) if v.Required { sc.Required = append(sc.Required, k) } } return sc, nil } return p.jsonschema, nil } func paramInfoToJSONSchema(paramInfo *ParameterInfo) *jsonschema.Schema { js := &jsonschema.Schema{ Type: string(paramInfo.Type), Description: paramInfo.Desc, } if len(paramInfo.Enum) > 0 { js.Enum = make([]any, len(paramInfo.Enum)) for i, enum := range paramInfo.Enum { js.Enum[i] = enum } } if paramInfo.ElemInfo != nil { js.Items = paramInfoToJSONSchema(paramInfo.ElemInfo) } if len(paramInfo.SubParams) > 0 { required := make([]string, 0, len(paramInfo.SubParams)) js.Properties = orderedmap.New[string, *jsonschema.Schema]() keys := make([]string, 0, len(paramInfo.SubParams)) for k := range paramInfo.SubParams { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { v := paramInfo.SubParams[k] item := paramInfoToJSONSchema(v) js.Properties.Set(k, item) if v.Required { required = append(required, k) } } js.Required = required } return js } ================================================ FILE: schema/tool_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package schema import ( "encoding/json" "testing" "github.com/eino-contrib/jsonschema" "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert" ) func TestParamsOneOfToJSONSchema(t *testing.T) { convey.Convey("ParamsOneOfToJSONSchema", t, func() { var ( oneOf ParamsOneOf converted any err error ) convey.Convey("user provides JSON schema directly, use what the user provides", func() { oneOf.jsonschema = &jsonschema.Schema{ Type: "string", Description: "this is the only argument", } converted, err = oneOf.ToJSONSchema() convey.So(err, convey.ShouldBeNil) convey.So(converted, convey.ShouldResemble, oneOf.jsonschema) }) convey.Convey("user provides map[string]ParameterInfo, converts to json schema", func() { oneOf.params = map[string]*ParameterInfo{ "arg1": { Type: String, Desc: "this is the first argument", Required: true, Enum: []string{"1", "2"}, }, "arg2": { Type: Object, Desc: "this is the second argument", SubParams: map[string]*ParameterInfo{ "sub_arg1": { Type: String, Desc: "this is the sub argument", Required: true, Enum: []string{"1", "2"}, }, "sub_arg2": { Type: String, Desc: "this is the sub argument 2", }, }, Required: true, }, "arg3": { Type: Array, Desc: "this is the third argument", ElemInfo: &ParameterInfo{ Type: String, Desc: "this is the element of the third argument", Required: true, Enum: []string{"1", "2"}, }, Required: true, }, } converted, err = oneOf.ToJSONSchema() convey.So(err, convey.ShouldBeNil) }) convey.Convey("user provides map[string]ParameterInfo, converts to json schema in order", func() { params := 
&ParamsOneOf{ params: map[string]*ParameterInfo{ "c": { Type: "string", }, "a": { Type: "object", SubParams: map[string]*ParameterInfo{ "z": { Type: "number", }, "y": { Type: "string", }, }, }, "b": { Type: "array", ElemInfo: &ParameterInfo{ Type: "object", SubParams: map[string]*ParameterInfo{ "p": { Type: "integer", }, "o": { Type: "boolean", }, }, }, }, }, } schema1, err := params.ToJSONSchema() assert.NoError(t, err) json1, err := json.Marshal(schema1) assert.NoError(t, err) schema2, err := params.ToJSONSchema() assert.NoError(t, err) json2, err := json.Marshal(schema2) assert.NoError(t, err) assert.Equal(t, string(json1), string(json2)) }) }) } ================================================ FILE: scripts/dev_setup.sh ================================================ #!/usr/bin/env bash # dev_setup.sh — set up a local multi-module workspace for eino development. # # BACKGROUND # eino, eino-ext, and eino-examples live in separate GitHub repositories to # keep their Go modules, versioning, and maintenance independent. However, # working across them is inconvenient: editors and AI coding tools lack # cross-repo type information and can't navigate between them. # # This script brings all three repos together locally so that a single # go.work file provides full cross-module LSP (go-to-definition, type # inference, autocomplete) across all ~83 modules — without touching any # remote repository. # # WHAT IT DOES # 1. Clones eino-ext → ext/ # 2. Clones eino-examples → examples/ # 3. Registers ext/ and examples/ in .git/info/exclude so eino's git # never sees them (local-only, never committed) # 4. 
Creates go.work at the repo root covering eino + all modules in
#      ext/ and examples/ (go.work is already in .gitignore)
#
# RESULTING LAYOUT
#   eino/            ← you are here (github.com/cloudwego/eino)
#   eino/ext/        ← github.com/cloudwego/eino-ext (full git repo)
#   eino/examples/   ← github.com/cloudwego/eino-examples (full git repo)
#   eino/go.work     ← wires all modules together (gitignored)
#
# WORKING ACROSS REPOS
#   Each subdirectory is a full independent git repo tracking its own remote.
#   To contribute to eino-ext or eino-examples, work inside that directory:
#
#     cd ext
#     git checkout -b feat/my-feature
#     # make changes — editor has full cross-repo type info via go.work
#     git commit -m "feat: ..."
#     git push origin feat/my-feature   # pushes to cloudwego/eino-ext
#
# KEEPING REPOS UP TO DATE
#   git -C ext pull
#   git -C examples pull
#
# USAGE
#   bash scripts/dev_setup.sh           # first-time setup
#   bash scripts/dev_setup.sh --reset   # re-clone everything from scratch

set -euo pipefail

REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$REPO_ROOT"

EXT_DIR="ext"
EXAMPLES_DIR="examples"
EINO_EXT_REPO="https://github.com/cloudwego/eino-ext"
EINO_EXAMPLES_REPO="https://github.com/cloudwego/eino-examples"

# Parse flags. Unknown flags are rejected instead of silently ignored, so a
# typo such as "--rest" cannot quietly skip the reset.
RESET=false
for arg in "$@"; do
  case $arg in
    --reset) RESET=true ;;
    *)
      echo "error: unknown flag '$arg' (supported: --reset)" >&2
      exit 2
      ;;
  esac
done

echo "==> Setting up eino dev workspace in: $REPO_ROOT"

# --reset: remove existing dirs
if [ "$RESET" = true ]; then
  echo "==> --reset: removing $EXT_DIR/ and $EXAMPLES_DIR/"
  rm -rf "$EXT_DIR" "$EXAMPLES_DIR" go.work go.work.sum
fi

# Clone repos if not already present
if [ ! -d "$EXT_DIR/.git" ]; then
  echo "==> Cloning eino-ext into $EXT_DIR/"
  git clone "$EINO_EXT_REPO" "$EXT_DIR"
else
  echo "==> $EXT_DIR/ already exists, skipping clone"
fi

if [ ! -d "$EXAMPLES_DIR/.git" ]; then
  echo "==> Cloning eino-examples into $EXAMPLES_DIR/"
  git clone "$EINO_EXAMPLES_REPO" "$EXAMPLES_DIR"
else
  echo "==> $EXAMPLES_DIR/ already exists, skipping clone"
fi

# Exclude dirs from eino's git tracking (local only, not committed).
# Ask git where the exclude file lives instead of hard-coding
# .git/info/exclude: in a linked worktree .git is a file, not a directory.
EXCLUDE_FILE="$(git rev-parse --git-path info/exclude 2>/dev/null || echo ".git/info/exclude")"
mkdir -p "$(dirname "$EXCLUDE_FILE")"

add_exclude() {
  local entry="$1"
  if ! grep -qxF "$entry" "$EXCLUDE_FILE" 2>/dev/null; then
    echo "$entry" >> "$EXCLUDE_FILE"
    echo "==> Added '$entry' to $EXCLUDE_FILE"
  fi
}

add_exclude "$EXT_DIR/"
add_exclude "$EXAMPLES_DIR/"

# Build go.work covering eino root + every go.mod found in ext/ and examples/
if [ ! -f "go.work" ]; then
  echo "==> Creating go.work"
  go work init .

  # Collect all module directories (directories containing a go.mod) and
  # register them with a single `go work use` call rather than one per module.
  mod_dirs=()
  while IFS= read -r modfile; do
    mod_dirs+=("$(dirname "$modfile")")
  done < <(find "$EXT_DIR" "$EXAMPLES_DIR" -name "go.mod" | sort)
  # Guard the expansion: "${arr[@]}" on an empty array trips `set -u` in bash 3.2.
  if [ "${#mod_dirs[@]}" -gt 0 ]; then
    go work use "${mod_dirs[@]}"
  fi

  # Count indented "use" entries (whitespace followed by "."); POSIX classes
  # keep this working with BSD grep on macOS as well as GNU grep.
  echo "==> go.work created with $(grep -cE '^[[:space:]]+\.' go.work || true) module(s)"
else
  echo "==> go.work already exists, skipping (use --reset to recreate)"
fi

echo ""
echo "Done. Your workspace includes:"
echo " . — github.com/cloudwego/eino"
echo " $EXT_DIR/ — github.com/cloudwego/eino-ext ($(find "$EXT_DIR" -name "go.mod" | wc -l | tr -d ' ') modules)"
echo " $EXAMPLES_DIR/ — github.com/cloudwego/eino-examples ($(find "$EXAMPLES_DIR" -name "go.mod" | wc -l | tr -d ' ') modules)"
echo ""
echo "Run 'go build ./...' or open this directory in your editor."
================================================ FILE: scripts/eino_setup.sh ================================================ #!/usr/bin/env bash # eino_setup.sh — fetch eino framework source into your project for AI-assisted development. # # BACKGROUND # When building applications with eino, your AI coding assistant (Claude Code, # Cursor, Copilot, etc.) only sees your code. It cannot navigate into eino's # source to understand how components work, what patterns are idiomatic, or # how to wire things together correctly.
# # This script clones eino, eino-ext, and eino-examples into a _eino/ directory # inside your project. Your AI assistant can then browse the actual source, # examples, and extensions — giving it full context to help you build correctly. # # WHAT IT DOES # 1. Clones eino → _eino/eino/ # 2. Clones eino-ext → _eino/eino-ext/ # 3. Clones eino-examples → _eino/eino-examples/ # 4. Adds _eino/ to .gitignore (read-only reference, never committed) # 5. Writes a _eino/README.md explaining the directory to future readers # # RESULTING LAYOUT # your-project/ # ├── _eino/ # │ ├── eino/ ← github.com/cloudwego/eino (core framework) # │ ├── eino-ext/ ← github.com/cloudwego/eino-ext (components & integrations) # │ └── eino-examples/ ← github.com/cloudwego/eino-examples (patterns & recipes) # └── ... your code # # NOTE: _eino/ is read-only reference material. Do not edit files inside it. # Your go.mod is unchanged — eino remains a normal dependency. # # KEEPING UP TO DATE # bash eino_setup.sh --update # pull latest on all three repos # # USAGE # bash eino_setup.sh # first-time setup # bash eino_setup.sh --reset # re-clone everything from scratch # bash eino_setup.sh --update # pull latest without re-cloning # # SYSTEM PROMPT # After running this script, add the following to your AI assistant's project # instructions (CLAUDE.md, .cursorrules, .github/copilot-instructions.md, etc.): # # --- # ## eino Framework Reference # # This project uses the eino framework (github.com/cloudwego/eino). # The full framework source is available locally in `_eino/`: # # - `_eino/eino/` — core framework (components, graph, compose, callbacks) # - `_eino/eino-ext/` — official components and integrations (models, tools, retrievers, etc.) # - `_eino/eino-examples/` — working examples and patterns # # When answering questions about eino APIs, component wiring, graph construction, # callbacks, or any eino-specific patterns: explore `_eino/` first. 
# Prefer examples from `_eino/eino-examples/` as the canonical reference for # idiomatic usage. # --- set -euo pipefail PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$PROJECT_ROOT" EINO_DIR="_eino" EINO_REPO="https://github.com/cloudwego/eino" EINO_EXT_REPO="https://github.com/cloudwego/eino-ext" EINO_EXAMPLES_REPO="https://github.com/cloudwego/eino-examples" # Parse flags RESET=false UPDATE=false for arg in "$@"; do case $arg in --reset) RESET=true ;; --update) UPDATE=true ;; esac done echo "==> eino setup in: $PROJECT_ROOT" # --reset: remove and re-clone if [ "$RESET" = true ]; then echo "==> --reset: removing $EINO_DIR/" rm -rf "$EINO_DIR" fi # --update: pull latest on existing clones if [ "$UPDATE" = true ]; then for repo in eino eino-ext eino-examples; do dir="$EINO_DIR/$repo" if [ -d "$dir/.git" ]; then echo "==> Updating $dir/" git -C "$dir" pull --ff-only else echo "==> $dir/ not found, skipping update (run without --update to clone)" fi done echo "" echo "Done. Run 'bash eino_setup.sh' to clone any missing repos." exit 0 fi mkdir -p "$EINO_DIR" # Clone repos (shallow — we only need source to read, not full history) clone_if_missing() { local repo_url="$1" local dest="$2" if [ ! -d "$dest/.git" ]; then echo "==> Cloning $(basename "$dest")/" git clone --depth=1 "$repo_url" "$dest" else echo "==> $dest/ already exists, skipping clone" fi } clone_if_missing "$EINO_REPO" "$EINO_DIR/eino" clone_if_missing "$EINO_EXT_REPO" "$EINO_DIR/eino-ext" clone_if_missing "$EINO_EXAMPLES_REPO" "$EINO_DIR/eino-examples" # Add _eino/ to .gitignore GITIGNORE=".gitignore" if ! 
grep -qxF "$EINO_DIR/" "$GITIGNORE" 2>/dev/null; then echo "" >> "$GITIGNORE" echo "# eino framework source (AI coding reference — see eino_setup.sh)" >> "$GITIGNORE" echo "$EINO_DIR/" >> "$GITIGNORE" echo "==> Added '$EINO_DIR/' to $GITIGNORE" fi # Write a README so the directory is self-explanatory cat > "$EINO_DIR/README.md" <<'EOF' # _eino — eino framework source reference This directory contains read-only clones of the eino framework repositories, checked out for use by AI coding assistants (Claude Code, Cursor, Copilot, etc.). | Directory | Repository | Purpose | |----------------|-----------------------------------------|--------------------------------| | `eino/` | github.com/cloudwego/eino | Core framework source | | `eino-ext/` | github.com/cloudwego/eino-ext | Components and integrations | | `eino-examples/` | github.com/cloudwego/eino-examples | Patterns, recipes, and samples | **Do not edit files here.** This directory is in `.gitignore` and is never committed. To update to the latest: bash eino_setup.sh --update To re-clone from scratch: bash eino_setup.sh --reset EOF echo "" echo "Done. Your AI assistant now has full eino context in $EINO_DIR/:" echo " $EINO_DIR/eino/ — core framework ($(find "$EINO_DIR/eino" -name "*.go" | wc -l | tr -d ' ') .go files)" echo " $EINO_DIR/eino-ext/ — components & integrations ($(find "$EINO_DIR/eino-ext" -name "*.go" | wc -l | tr -d ' ') .go files)" echo " $EINO_DIR/eino-examples/ — patterns & recipes ($(find "$EINO_DIR/eino-examples" -name "*.go" | wc -l | tr -d ' ') .go files)" echo "" echo "Add the following to your AI assistant's system prompt or project instructions" echo "(e.g. CLAUDE.md, .cursorrules, .github/copilot-instructions.md):" echo "" echo "---" cat <<'PROMPT' ## eino Framework Reference This project uses the eino framework (github.com/cloudwego/eino). 
The full framework source is available locally in `_eino/`: - `_eino/eino/` — core framework (components, graph, compose, callbacks) - `_eino/eino-ext/` — official components and integrations (models, tools, retrievers, etc.) - `_eino/eino-examples/` — working examples and patterns When answering questions about eino APIs, component wiring, graph construction, callbacks, or any eino-specific patterns: explore `_eino/` first. Prefer examples from `_eino/eino-examples/` as the canonical reference for idiomatic usage. PROMPT echo "---" ================================================ FILE: utils/callbacks/template.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package callbacks provides ready-to-use callback handler templates for components. package callbacks import ( "context" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) // NewHandlerHelper creates a new component template handler builder. 
// This builder can be used to configure and build a component template handler, // which can handle callback events for different components with its own struct definition, // and fallbackTemplate can be used to handle scenarios where none of the cases are hit as a fallback. func NewHandlerHelper() *HandlerHelper { return &HandlerHelper{ composeTemplates: map[components.Component]callbacks.Handler{}, } } // HandlerHelper is a builder for creating a callbacks.Handler with specific handlers for different component types. // create a handler with callbacks.NewHandlerHelper(). // eg. // // helper := template.NewHandlerHelper(). // ChatModel(&model.IndexerCallbackHandler{}). // Prompt(&prompt.IndexerCallbackHandler{}). // Handler() // // then use the handler with runnable.Invoke(ctx, input, compose.WithCallbacks(handler)) type HandlerHelper struct { promptHandler *PromptCallbackHandler chatModelHandler *ModelCallbackHandler embeddingHandler *EmbeddingCallbackHandler indexerHandler *IndexerCallbackHandler retrieverHandler *RetrieverCallbackHandler loaderHandler *LoaderCallbackHandler transformerHandler *TransformerCallbackHandler toolHandler *ToolCallbackHandler toolsNodeHandler *ToolsNodeCallbackHandlers agentHandler *AgentCallbackHandler composeTemplates map[components.Component]callbacks.Handler } // Handler returns the callbacks.Handler created by HandlerHelper. func (c *HandlerHelper) Handler() callbacks.Handler { return &handlerTemplate{c} } // Prompt sets the prompt handler for the handler helper, which will be called when the prompt component is executed. func (c *HandlerHelper) Prompt(handler *PromptCallbackHandler) *HandlerHelper { c.promptHandler = handler return c } // ChatModel sets the chat model handler for the handler helper, which will be called when the chat model component is executed. 
func (c *HandlerHelper) ChatModel(handler *ModelCallbackHandler) *HandlerHelper { c.chatModelHandler = handler return c } // Embedding sets the embedding handler for the handler helper, which will be called when the embedding component is executed. func (c *HandlerHelper) Embedding(handler *EmbeddingCallbackHandler) *HandlerHelper { c.embeddingHandler = handler return c } // Indexer sets the indexer handler for the handler helper, which will be called when the indexer component is executed. func (c *HandlerHelper) Indexer(handler *IndexerCallbackHandler) *HandlerHelper { c.indexerHandler = handler return c } // Retriever sets the retriever handler for the handler helper, which will be called when the retriever component is executed. func (c *HandlerHelper) Retriever(handler *RetrieverCallbackHandler) *HandlerHelper { c.retrieverHandler = handler return c } // Loader sets the loader handler for the handler helper, which will be called when the loader component is executed. func (c *HandlerHelper) Loader(handler *LoaderCallbackHandler) *HandlerHelper { c.loaderHandler = handler return c } // Transformer sets the transformer handler for the handler helper, which will be called when the transformer component is executed. func (c *HandlerHelper) Transformer(handler *TransformerCallbackHandler) *HandlerHelper { c.transformerHandler = handler return c } // Tool sets the tool handler for the handler helper, which will be called when the tool component is executed. func (c *HandlerHelper) Tool(handler *ToolCallbackHandler) *HandlerHelper { c.toolHandler = handler return c } // ToolsNode sets the tools node handler for the handler helper, which will be called when the tools node is executed. func (c *HandlerHelper) ToolsNode(handler *ToolsNodeCallbackHandlers) *HandlerHelper { c.toolsNodeHandler = handler return c } // Agent sets the agent handler for the handler helper, which will be called when the agent is executed. 
func (c *HandlerHelper) Agent(handler *AgentCallbackHandler) *HandlerHelper { c.agentHandler = handler return c } // Graph sets the graph handler for the handler helper, which will be called when the graph is executed. func (c *HandlerHelper) Graph(handler callbacks.Handler) *HandlerHelper { c.composeTemplates[compose.ComponentOfGraph] = handler return c } // Chain sets the chain handler for the handler helper, which will be called when the chain is executed. func (c *HandlerHelper) Chain(handler callbacks.Handler) *HandlerHelper { c.composeTemplates[compose.ComponentOfChain] = handler return c } // Lambda sets the lambda handler for the handler helper, which will be called when the lambda is executed. func (c *HandlerHelper) Lambda(handler callbacks.Handler) *HandlerHelper { c.composeTemplates[compose.ComponentOfLambda] = handler return c } type handlerTemplate struct { *HandlerHelper } // OnStart is the callback function for the start event of a component. // implement the callbacks Handler interface. 
func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: return c.indexerHandler.OnStart(ctx, info, indexer.ConvCallbackInput(input)) case components.ComponentOfRetriever: return c.retrieverHandler.OnStart(ctx, info, retriever.ConvCallbackInput(input)) case components.ComponentOfLoader: return c.loaderHandler.OnStart(ctx, info, document.ConvLoaderCallbackInput(input)) case components.ComponentOfTransformer: return c.transformerHandler.OnStart(ctx, info, document.ConvTransformerCallbackInput(input)) case components.ComponentOfTool: return c.toolHandler.OnStart(ctx, info, tool.ConvCallbackInput(input)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnStart(ctx, info, convToolsNodeCallbackInput(input)) case adk.ComponentOfAgent: return c.agentHandler.OnStart(ctx, info, adk.ConvAgentCallbackInput(input)) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: return c.composeTemplates[info.Component].OnStart(ctx, info, input) default: return ctx } } // OnEnd is the callback function for the end event of a component. // implement the callbacks Handler interface. 
func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: return c.indexerHandler.OnEnd(ctx, info, indexer.ConvCallbackOutput(output)) case components.ComponentOfRetriever: return c.retrieverHandler.OnEnd(ctx, info, retriever.ConvCallbackOutput(output)) case components.ComponentOfLoader: return c.loaderHandler.OnEnd(ctx, info, document.ConvLoaderCallbackOutput(output)) case components.ComponentOfTransformer: return c.transformerHandler.OnEnd(ctx, info, document.ConvTransformerCallbackOutput(output)) case components.ComponentOfTool: return c.toolHandler.OnEnd(ctx, info, tool.ConvCallbackOutput(output)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnEnd(ctx, info, convToolsNodeCallbackOutput(output)) case adk.ComponentOfAgent: return c.agentHandler.OnEnd(ctx, info, adk.ConvAgentCallbackOutput(output)) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: return c.composeTemplates[info.Component].OnEnd(ctx, info, output) default: return ctx } } // OnError is the callback function for the error event of a component. // implement the callbacks Handler interface. 
func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnError(ctx, info, err) case components.ComponentOfChatModel: return c.chatModelHandler.OnError(ctx, info, err) case components.ComponentOfEmbedding: return c.embeddingHandler.OnError(ctx, info, err) case components.ComponentOfIndexer: return c.indexerHandler.OnError(ctx, info, err) case components.ComponentOfRetriever: return c.retrieverHandler.OnError(ctx, info, err) case components.ComponentOfLoader: return c.loaderHandler.OnError(ctx, info, err) case components.ComponentOfTransformer: return c.transformerHandler.OnError(ctx, info, err) case components.ComponentOfTool: return c.toolHandler.OnError(ctx, info, err) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnError(ctx, info, err) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: return c.composeTemplates[info.Component].OnError(ctx, info, err) default: return ctx } } // OnStartWithStreamInput is the callback function for the start event of a component with stream input. // implement the callbacks Handler interface. func (c *handlerTemplate) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { switch info.Component { // currently no components.Component receive stream as input case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: return c.composeTemplates[info.Component].OnStartWithStreamInput(ctx, info, input) default: return ctx } } // OnEndWithStreamOutput is the callback function for the end event of a component with stream output. // implement the callbacks Handler interface. 
func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { switch info.Component { case components.ComponentOfChatModel: return c.chatModelHandler.OnEndWithStreamOutput(ctx, info, schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { return model.ConvCallbackOutput(item), nil })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*tool.CallbackOutput, error) { return tool.ConvCallbackOutput(item), nil })) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnEndWithStreamOutput(ctx, info, schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.Message, error) { return convToolsNodeCallbackOutput(item), nil })) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: return c.composeTemplates[info.Component].OnEndWithStreamOutput(ctx, info, output) default: return ctx } } // Needed checks if the callback handler is needed for the given timing. 
func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { if info == nil { return false } switch info.Component { case components.ComponentOfChatModel: if c.chatModelHandler != nil && c.chatModelHandler.Needed(ctx, info, timing) { return true } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.Needed(ctx, info, timing) { return true } case components.ComponentOfIndexer: if c.indexerHandler != nil && c.indexerHandler.Needed(ctx, info, timing) { return true } case components.ComponentOfLoader: if c.loaderHandler != nil && c.loaderHandler.Needed(ctx, info, timing) { return true } case components.ComponentOfPrompt: if c.promptHandler != nil && c.promptHandler.Needed(ctx, info, timing) { return true } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.Needed(ctx, info, timing) { return true } case components.ComponentOfTool: if c.toolHandler != nil && c.toolHandler.Needed(ctx, info, timing) { return true } case components.ComponentOfTransformer: if c.transformerHandler != nil && c.transformerHandler.Needed(ctx, info, timing) { return true } case compose.ComponentOfToolsNode: if c.toolsNodeHandler != nil && c.toolsNodeHandler.Needed(ctx, info, timing) { return true } case adk.ComponentOfAgent: if c.agentHandler != nil && c.agentHandler.Needed(ctx, info, timing) { return true } case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: handler := c.composeTemplates[info.Component] if handler != nil { checker, ok := handler.(callbacks.TimingChecker) if !ok || checker.Needed(ctx, info, timing) { return true } } default: return false } return false } // LoaderCallbackHandler is the handler for the loader callback. 
type LoaderCallbackHandler struct { OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *document.LoaderCallbackOutput) context.Context OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } // Needed checks if the callback handler is needed for the given timing. func (ch *LoaderCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnError: return ch.OnError != nil default: return false } } // TransformerCallbackHandler is the handler for the transformer callback. type TransformerCallbackHandler struct { OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *document.TransformerCallbackOutput) context.Context OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } // Needed checks if the callback handler is needed for the given timing. func (ch *TransformerCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnError: return ch.OnError != nil default: return false } } // EmbeddingCallbackHandler is the handler for the embedding callback. 
type EmbeddingCallbackHandler struct { OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *embedding.CallbackInput) context.Context OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *embedding.CallbackOutput) context.Context OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } // Needed checks if the callback handler is needed for the given timing. func (ch *EmbeddingCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnError: return ch.OnError != nil default: return false } } // IndexerCallbackHandler is the handler for the indexer callback. type IndexerCallbackHandler struct { OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *indexer.CallbackInput) context.Context OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *indexer.CallbackOutput) context.Context OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } // Needed checks if the callback handler is needed for the given timing. func (ch *IndexerCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnError: return ch.OnError != nil default: return false } } // ModelCallbackHandler is the handler for the model callback. 
type ModelCallbackHandler struct { OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } // Needed checks if the callback handler is needed for the given timing. func (ch *ModelCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnError: return ch.OnError != nil case callbacks.TimingOnEndWithStreamOutput: return ch.OnEndWithStreamOutput != nil default: return false } } // PromptCallbackHandler is the handler for the callback. type PromptCallbackHandler struct { // OnStart is the callback function for the start of the callback. OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context // OnEnd is the callback function for the end of the callback. OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context // OnError is the callback function for the error of the callback. OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } // Needed checks if the callback handler is needed for the given timing. 
func (ch *PromptCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnError: return ch.OnError != nil default: return false } } // RetrieverCallbackHandler is the handler for the retriever callback. type RetrieverCallbackHandler struct { // OnStart is the callback function for the start of the retriever. OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *retriever.CallbackInput) context.Context // OnEnd is the callback function for the end of the retriever. OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *retriever.CallbackOutput) context.Context // OnError is the callback function for the error of the retriever. OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } // Needed checks if the callback handler is needed for the given timing. func (ch *RetrieverCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnError: return ch.OnError != nil default: return false } } // ToolCallbackHandler is the handler for the tool callback. type ToolCallbackHandler struct { OnStart func(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackInput) context.Context OnEnd func(ctx context.Context, info *callbacks.RunInfo, output *tool.CallbackOutput) context.Context OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*tool.CallbackOutput]) context.Context OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context } // Needed checks if the callback handler is needed for the given timing. 
func (ch *ToolCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnEndWithStreamOutput: return ch.OnEndWithStreamOutput != nil case callbacks.TimingOnError: return ch.OnError != nil default: return false } } // ToolsNodeCallbackHandlers defines optional callbacks for the Tools node // lifecycle events. type ToolsNodeCallbackHandlers struct { OnStart func(ctx context.Context, info *callbacks.RunInfo, input *schema.Message) context.Context OnEnd func(ctx context.Context, info *callbacks.RunInfo, input []*schema.Message) context.Context OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.Message]) context.Context OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context } // Needed reports whether a handler is registered for the given timing. 
func (ch *ToolsNodeCallbackHandlers) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil case callbacks.TimingOnEndWithStreamOutput: return ch.OnEndWithStreamOutput != nil case callbacks.TimingOnError: return ch.OnError != nil default: return false } } func convToolsNodeCallbackInput(src callbacks.CallbackInput) *schema.Message { switch t := src.(type) { case *schema.Message: return t default: return nil } } func convToolsNodeCallbackOutput(src callbacks.CallbackInput) []*schema.Message { switch t := src.(type) { case []*schema.Message: return t default: return nil } } type AgentCallbackHandler struct { OnStart func(ctx context.Context, info *callbacks.RunInfo, input *adk.AgentCallbackInput) context.Context OnEnd func(ctx context.Context, info *callbacks.RunInfo, output *adk.AgentCallbackOutput) context.Context } func (ch *AgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { switch timing { case callbacks.TimingOnStart: return ch.OnStart != nil case callbacks.TimingOnEnd: return ch.OnEnd != nil default: return false } } ================================================ FILE: utils/callbacks/template_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

package callbacks

import (
	"context"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/cloudwego/eino/adk"
	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components"
	"github.com/cloudwego/eino/components/document"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/indexer"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/components/tool"
	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/schema"
)

// TestNewComponentTemplate registers typed handlers through NewHandlerHelper and
// uses a shared counter to check which callbacks actually reach them, both when
// the template handler is invoked directly and when it is driven through the
// callbacks package (InitCallbacks / ReuseHandlers).
func TestNewComponentTemplate(t *testing.T) {
	t.Run("TestNewComponentTemplate", func(t *testing.T) {
		// cnt is bumped once by every callback that reaches a registered handler;
		// the assertions below depend on the exact order of the calls.
		cnt := 0
		tpl := NewHandlerHelper()
		tpl.ChatModel(&ModelCallbackHandler{
			OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context {
				cnt++
				return ctx
			},
			OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
				cnt++
				return ctx
			},
			OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context {
				output.Close()
				cnt++
				return ctx
			},
			OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
				cnt++
				return ctx
			}}).
			Embedding(&EmbeddingCallbackHandler{
				OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *embedding.CallbackInput) context.Context {
					cnt++
					return ctx
				},
				OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *embedding.CallbackOutput) context.Context {
					cnt++
					return ctx
				},
				OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
					cnt++
					return ctx
				},
			}).
			Prompt(&PromptCallbackHandler{
				OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context {
					cnt++
					return ctx
				},
				OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context {
					cnt++
					return ctx
				},
				OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
					cnt++
					return ctx
				},
			}).
			Retriever(&RetrieverCallbackHandler{
				OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *retriever.CallbackInput) context.Context {
					cnt++
					return ctx
				},
				OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *retriever.CallbackOutput) context.Context {
					cnt++
					return ctx
				},
				OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
					cnt++
					return ctx
				},
			}).
			Tool(&ToolCallbackHandler{
				OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *tool.CallbackInput) context.Context {
					cnt++
					return ctx
				},
				OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *tool.CallbackOutput) context.Context {
					cnt++
					return ctx
				},
				OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*tool.CallbackOutput]) context.Context {
					cnt++
					return ctx
				},
				OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
					cnt++
					return ctx
				},
			}).
			// Lambdas use a generic (untyped) handler built with HandlerBuilder.
			Lambda(callbacks.NewHandlerBuilder().
				OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
					cnt++
					return ctx
				}).
				OnStartWithStreamInputFn(func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
					input.Close()
					cnt++
					return ctx
				}).
				OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
					cnt++
					return ctx
				}).
				OnEndWithStreamOutputFn(func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
					output.Close()
					cnt++
					return ctx
				}).
				OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
					cnt++
					return ctx
				}).Build()).
			Handler()

		// Phase 1: call the template handler directly, for every timing, on each
		// component type registered above.
		types := []components.Component{
			components.ComponentOfPrompt,
			components.ComponentOfChatModel,
			components.ComponentOfEmbedding,
			components.ComponentOfRetriever,
			components.ComponentOfTool,
			compose.ComponentOfLambda,
		}

		handler := tpl.Handler()
		ctx := context.Background()
		for _, typ := range types {
			handler.OnStart(ctx, &callbacks.RunInfo{Component: typ}, nil)
			handler.OnEnd(ctx, &callbacks.RunInfo{Component: typ}, nil)
			handler.OnError(ctx, &callbacks.RunInfo{Component: typ}, fmt.Errorf("mock err"))
			// Streams are pre-closed so the handlers see an immediately-ended stream.
			sir, siw := schema.Pipe[callbacks.CallbackInput](1)
			siw.Close()
			handler.OnStartWithStreamInput(ctx, &callbacks.RunInfo{Component: typ}, sir)
			sor, sow := schema.Pipe[callbacks.CallbackOutput](1)
			sow.Close()
			handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: typ}, sor)
		}
		// Expected total of invocations routed to the handlers registered above.
		assert.Equal(t, 22, cnt)

		// Phase 2: drive the same handler through the callbacks package. Components
		// with no registered handler (transformer, indexer, loader at this point)
		// must leave cnt unchanged.
		ctx = context.Background()
		ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler)
		callbacks.OnStart[any](ctx, nil)
		assert.Equal(t, 22, cnt)
		ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt})
		ctx = callbacks.OnStart[any](ctx, nil)
		assert.Equal(t, 23, cnt)
		ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer})
		callbacks.OnEnd[any](ctx, nil)
		assert.Equal(t, 23, cnt)
		ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding})
		callbacks.OnError(ctx, nil)
		assert.Equal(t, 24, cnt)
		ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader})
		callbacks.OnStart[any](ctx, nil)
		assert.Equal(t, 24, cnt)

		// Phase 3: register the remaining component handlers on the same helper.
		tpl.Transformer(&TransformerCallbackHandler{
			OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context {
				cnt++
				return ctx
			},
			OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *document.TransformerCallbackOutput) context.Context {
				cnt++
				return ctx
			},
			OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
				cnt++
				return ctx
			},
		}).Indexer(&IndexerCallbackHandler{
			OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *indexer.CallbackInput) context.Context {
				cnt++
				return ctx
			},
			OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *indexer.CallbackOutput) context.Context {
				cnt++
				return ctx
			},
			OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
				cnt++
				return ctx
			},
		}).Loader(&LoaderCallbackHandler{
			OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context {
				cnt++
				return ctx
			},
			OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *document.LoaderCallbackOutput) context.Context {
				cnt++
				return ctx
			},
			OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
				cnt++
				return ctx
			},
		}).ToolsNode(&ToolsNodeCallbackHandlers{
			// Only OnStart and OnEndWithStreamOutput are set; the OnError call on the
			// tools node below must therefore not bump cnt.
			OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *schema.Message) context.Context {
				cnt++
				return ctx
			},
			OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[[]*schema.Message]) context.Context {
				cnt++
				if output == nil {
					return ctx
				}
				// Drain the stream until EOF (or any error).
				for {
					_, err := output.Recv()
					if err != nil {
						return ctx
					}
				}
			},
		})

		handler = tpl.Handler()
		ctx = context.Background()
		ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler)
		ctx = callbacks.OnStart[any](ctx, nil)
		assert.Equal(t, 25, cnt)
		callbacks.OnEnd[any](ctx, nil)
		assert.Equal(t, 26, cnt)
		ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader})
		callbacks.OnEnd[any](ctx, nil)
		assert.Equal(t, 27, cnt)
		ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode})
		callbacks.OnStart[any](ctx, nil)
		assert.Equal(t, 28, cnt)
		// An empty, already-closed stream still reaches the tools-node stream handler.
		sr, sw := schema.Pipe[any](0)
		sw.Close()
		callbacks.OnEndWithStreamOutput[any](ctx, sr)
		assert.Equal(t, 29, cnt)
		// A stream carrying one message slice reaches it too.
		sr1, sw1 := schema.Pipe[[]*schema.Message](1)
		sw1.Send([]*schema.Message{{}}, nil)
		sw1.Close()
		callbacks.OnEndWithStreamOutput[[]*schema.Message](ctx, sr1)
		assert.Equal(t, 30, cnt)
		callbacks.OnError(ctx, nil)
		assert.Equal(t, 30, cnt)
		// With nil RunInfo nothing may fire.
		ctx = callbacks.ReuseHandlers(ctx, nil)
		callbacks.OnStart[any](ctx, nil)
		assert.Equal(t, 30, cnt)
	})
}

// TestAgentCallbackHandler checks that AgentCallbackHandler.Needed reflects
// exactly which of OnStart/OnEnd are set.
func TestAgentCallbackHandler(t *testing.T) {
	t.Run("Needed returns correct values", func(t *testing.T) {
		handler := &AgentCallbackHandler{
			OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *adk.AgentCallbackInput) context.Context {
				return ctx
			},
		}
		ctx := context.Background()
		info := &callbacks.RunInfo{Component: adk.ComponentOfAgent}
		assert.True(t, handler.Needed(ctx, info, callbacks.TimingOnStart))
		assert.False(t, handler.Needed(ctx, info, callbacks.TimingOnEnd))
	})
	t.Run("Needed with OnEnd set", func(t *testing.T) {
		handler := &AgentCallbackHandler{
			OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output *adk.AgentCallbackOutput) context.Context {
				return ctx
			},
		}
		ctx := context.Background()
		info := &callbacks.RunInfo{Component: adk.ComponentOfAgent}
		assert.False(t, handler.Needed(ctx, info, callbacks.TimingOnStart))
		assert.True(t, handler.Needed(ctx, info, callbacks.TimingOnEnd))
	})
	t.Run("Needed with nil handlers", func(t *testing.T) {
		handler := &AgentCallbackHandler{}
		ctx := context.Background()
		info := &callbacks.RunInfo{Component: adk.ComponentOfAgent}
		assert.False(t, handler.Needed(ctx, info, callbacks.TimingOnStart))
		assert.False(t, handler.Needed(ctx, info, callbacks.TimingOnEnd))
	})
}

// TestHandlerHelperWithAgent checks that an agent handler registered through
// HandlerHelper.Agent is reached via InitCallbacks/OnStart/OnEnd.
func TestHandlerHelperWithAgent(t *testing.T) {
	t.Run("Agent method sets handler correctly", func(t *testing.T) {
		cnt := 0
		tpl := NewHandlerHelper()
		tpl.Agent(&AgentCallbackHandler{
			OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *adk.AgentCallbackInput) context.Context {
				cnt++
				return ctx
			},
			OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output *adk.AgentCallbackOutput) context.Context {
				cnt++
				return ctx
			},
		})
		handler := tpl.Handler()
		ctx := context.Background()
		ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: adk.ComponentOfAgent}, handler)
		ctx = callbacks.OnStart[any](ctx, nil)
		assert.Equal(t, 1, cnt)
		callbacks.OnEnd[any](ctx, nil)
		assert.Equal(t, 2, cnt)
	})
}

// TestHandlerTemplateWithAgentComponent calls the template handler directly with
// agent run info and checks the routing and the TimingChecker implementation.
func TestHandlerTemplateWithAgentComponent(t *testing.T) {
	t.Run("OnStart routes to agent handler", func(t *testing.T) {
		called := false
		tpl := NewHandlerHelper()
		tpl.Agent(&AgentCallbackHandler{
			OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *adk.AgentCallbackInput) context.Context {
				called = true
				return ctx
			},
		})
		handler := tpl.Handler()
		ctx := context.Background()
		info := &callbacks.RunInfo{Component: adk.ComponentOfAgent, Name: "TestAgent"}
		handler.OnStart(ctx, info, &adk.AgentCallbackInput{})
		assert.True(t, called)
	})
	t.Run("OnEnd routes to agent handler", func(t *testing.T) {
		called := false
		tpl := NewHandlerHelper()
		tpl.Agent(&AgentCallbackHandler{
			OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output *adk.AgentCallbackOutput) context.Context {
				called = true
				return ctx
			},
		})
		handler := tpl.Handler()
		ctx := context.Background()
		info := &callbacks.RunInfo{Component: adk.ComponentOfAgent, Name: "TestAgent"}
		handler.OnEnd(ctx, info, &adk.AgentCallbackOutput{})
		assert.True(t, called)
	})
	t.Run("Needed returns true for agent component", func(t *testing.T) {
		tpl := NewHandlerHelper()
		tpl.Agent(&AgentCallbackHandler{
			OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *adk.AgentCallbackInput) context.Context {
				return ctx
			},
		})
		handler := tpl.Handler()
		ctx := context.Background()
		info := &callbacks.RunInfo{Component: adk.ComponentOfAgent}
		// The built handler must expose Needed so the framework can skip unset timings.
		checker, ok := handler.(callbacks.TimingChecker)
		assert.True(t, ok, "handler should implement TimingChecker")
		assert.True(t, checker.Needed(ctx, info, callbacks.TimingOnStart))
	})
}